Skip to content

Commit

Permalink
[refactor] Remove dependencies on Program::this_thread_config() in Ke…
Browse files Browse the repository at this point in the history
…rnelCodeGen (#7086)

Issue: #7002

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PGZXB and pre-commit-ci[bot] committed Jan 12, 2023
1 parent db8c1f6 commit e3d8f73
Show file tree
Hide file tree
Showing 19 changed files with 68 additions and 37 deletions.
3 changes: 2 additions & 1 deletion taichi/codegen/amdgpu/codegen_amdgpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ namespace lang {

class KernelCodeGenAMDGPU : public KernelCodeGen {
public:
KernelCodeGenAMDGPU(Kernel *kernel) : KernelCodeGen(kernel) {
KernelCodeGenAMDGPU(const CompileConfig *config, Kernel *kernel)
: KernelCodeGen(config, kernel) {
}

// TODO: Stop defining this macro guards in the headers
Expand Down
23 changes: 13 additions & 10 deletions taichi/codegen/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,30 @@

namespace taichi::lang {

KernelCodeGen::KernelCodeGen(Kernel *kernel)
: prog(kernel->program), kernel(kernel) {
KernelCodeGen::KernelCodeGen(const CompileConfig *compile_config,
Kernel *kernel)
: prog(kernel->program), kernel(kernel), compile_config_(compile_config) {
this->ir = kernel->ir.get();
}

std::unique_ptr<KernelCodeGen> KernelCodeGen::create(Arch arch,
Kernel *kernel) {
std::unique_ptr<KernelCodeGen> KernelCodeGen::create(
const CompileConfig *compile_config,
Kernel *kernel) {
#ifdef TI_WITH_LLVM
const auto arch = compile_config->arch;
if (arch_is_cpu(arch) && arch != Arch::wasm) {
return std::make_unique<KernelCodeGenCPU>(kernel);
return std::make_unique<KernelCodeGenCPU>(compile_config, kernel);
} else if (arch == Arch::wasm) {
return std::make_unique<KernelCodeGenWASM>(kernel);
return std::make_unique<KernelCodeGenWASM>(compile_config, kernel);
} else if (arch == Arch::cuda) {
#if defined(TI_WITH_CUDA)
return std::make_unique<KernelCodeGenCUDA>(kernel);
return std::make_unique<KernelCodeGenCUDA>(compile_config, kernel);
#else
TI_NOT_IMPLEMENTED
#endif
} else if (arch == Arch::dx12) {
#if defined(TI_WITH_DX12)
return std::make_unique<KernelCodeGenDX12>(kernel);
return std::make_unique<KernelCodeGenDX12>(compile_config, kernel);
#else
TI_NOT_IMPLEMENTED
#endif
Expand All @@ -58,7 +61,7 @@ std::optional<LLVMCompiledKernel>
KernelCodeGen::maybe_read_compilation_from_cache(
const std::string &kernel_key) {
TI_AUTO_PROF;
const auto &config = prog->this_thread_config();
const auto &config = *compile_config_;
auto *llvm_prog = get_llvm_program(prog);
const auto &reader = llvm_prog->get_cache_reader();
if (!reader) {
Expand All @@ -82,8 +85,8 @@ void KernelCodeGen::cache_kernel(const std::string &kernel_key,
}

LLVMCompiledKernel KernelCodeGen::compile_kernel_to_module() {
const auto &config = *compile_config_;
auto *llvm_prog = get_llvm_program(prog);
const auto &config = prog->this_thread_config();
auto *tlctx = llvm_prog->get_llvm_context(config.arch);
std::string kernel_key = get_hashed_offline_cache_key(&config, kernel);
kernel->set_kernel_key_for_cache(kernel_key);
Expand Down
8 changes: 6 additions & 2 deletions taichi/codegen/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ class KernelCodeGen {
IRNode *ir;

public:
explicit KernelCodeGen(Kernel *kernel);
explicit KernelCodeGen(const CompileConfig *compile_config, Kernel *kernel);

virtual ~KernelCodeGen() = default;

static std::unique_ptr<KernelCodeGen> create(Arch arch, Kernel *kernel);
static std::unique_ptr<KernelCodeGen> create(
const CompileConfig *compile_config,
Kernel *kernel);

virtual FunctionType compile_to_function() = 0;
virtual bool supports_offline_cache() const {
Expand All @@ -65,6 +67,8 @@ class KernelCodeGen {
void cache_kernel(const std::string &kernel_key,
const LLVMCompiledKernel &data);
#endif
private:
const CompileConfig *compile_config_{nullptr};
};

#ifdef TI_WITH_LLVM
Expand Down
3 changes: 2 additions & 1 deletion taichi/codegen/cpu/codegen_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ namespace taichi::lang {

class KernelCodeGenCPU : public KernelCodeGen {
public:
explicit KernelCodeGenCPU(Kernel *kernel) : KernelCodeGen(kernel) {
explicit KernelCodeGenCPU(const CompileConfig *compile_config, Kernel *kernel)
: KernelCodeGen(compile_config, kernel) {
}

// TODO: Stop defining this macro guards in the headers
Expand Down
4 changes: 3 additions & 1 deletion taichi/codegen/cuda/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ namespace taichi::lang {

class KernelCodeGenCUDA : public KernelCodeGen {
public:
explicit KernelCodeGenCUDA(Kernel *kernel) : KernelCodeGen(kernel) {
explicit KernelCodeGenCUDA(const CompileConfig *compile_config,
Kernel *kernel)
: KernelCodeGen(compile_config, kernel) {
}

// TODO: Stop defining this macro guards in the headers
Expand Down
4 changes: 3 additions & 1 deletion taichi/codegen/dx12/codegen_dx12.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ namespace taichi::lang {

class KernelCodeGenDX12 : public KernelCodeGen {
public:
explicit KernelCodeGenDX12(Kernel *kernel) : KernelCodeGen(kernel) {
explicit KernelCodeGenDX12(const CompileConfig *compile_config,
Kernel *kernel)
: KernelCodeGen(compile_config, kernel) {
}
struct CompileResult {
std::vector<std::vector<uint8_t>> task_dxil_source_codes;
Expand Down
4 changes: 3 additions & 1 deletion taichi/codegen/wasm/codegen_wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ namespace taichi::lang {

class KernelCodeGenWASM : public KernelCodeGen {
public:
explicit KernelCodeGenWASM(Kernel *kernel) : KernelCodeGen(kernel) {
explicit KernelCodeGenWASM(const CompileConfig *compile_config,
Kernel *kernel)
: KernelCodeGen(compile_config, kernel) {
}

FunctionType compile_to_function() override;
Expand Down
2 changes: 1 addition & 1 deletion taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ std::unique_ptr<AotModuleBuilder> Program::make_aot_module_builder(
if (arch == Arch::wasm) {
// Have to check WASM first, or it dispatches to the LlvmProgramImpl.
#ifdef TI_WITH_LLVM
return std::make_unique<wasm::AotModuleBuilderImpl>();
return std::make_unique<wasm::AotModuleBuilderImpl>(&this_thread_config());
#else
TI_NOT_IMPLEMENTED
#endif
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cpu/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace taichi::lang {
namespace cpu {

LLVMCompiledKernel AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
auto cgen = KernelCodeGenCPU(kernel);
auto cgen = KernelCodeGenCPU(get_compile_config(), kernel);
return cgen.compile_kernel_to_module();
}

Expand Down
5 changes: 3 additions & 2 deletions taichi/runtime/cpu/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ namespace cpu {

class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
public:
explicit AotModuleBuilderImpl(LlvmProgramImpl *prog)
: LlvmAotModuleBuilder(prog) {
explicit AotModuleBuilderImpl(const CompileConfig *compile_config,
LlvmProgramImpl *prog)
: LlvmAotModuleBuilder(compile_config, prog) {
}

private:
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cuda/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace taichi::lang {
namespace cuda {

LLVMCompiledKernel AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
auto cgen = KernelCodeGenCUDA(kernel);
auto cgen = KernelCodeGenCUDA(get_compile_config(), kernel);
return cgen.compile_kernel_to_module();
}

Expand Down
5 changes: 3 additions & 2 deletions taichi/runtime/cuda/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ namespace cuda {

class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
public:
explicit AotModuleBuilderImpl(LlvmProgramImpl *prog)
: LlvmAotModuleBuilder(prog) {
explicit AotModuleBuilderImpl(const CompileConfig *compile_config,
LlvmProgramImpl *prog)
: LlvmAotModuleBuilder(compile_config, prog) {
}

private:
Expand Down
8 changes: 5 additions & 3 deletions taichi/runtime/dx12/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
namespace taichi::lang {
namespace directx12 {

AotModuleBuilderImpl::AotModuleBuilderImpl(LlvmProgramImpl *prog) : prog(prog) {
AotModuleBuilderImpl::AotModuleBuilderImpl(const CompileConfig &config,
LlvmProgramImpl *prog)
: config_(config), prog(prog) {
// FIXME: set correct root buffer size.
module_data.root_buffer_size = 1;
}
Expand All @@ -19,7 +21,7 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
auto &dxil_codes = module_data.dxil_codes[identifier];
auto &compiled_kernel = module_data.kernels[identifier];

KernelCodeGenDX12 cgen(kernel);
KernelCodeGenDX12 cgen(&config_, kernel);
auto compiled_data = cgen.compile();
for (auto &dxil : compiled_data.task_dxil_source_codes) {
dxil_codes.emplace_back(dxil);
Expand Down Expand Up @@ -69,7 +71,7 @@ void AotModuleBuilderImpl::add_per_backend_tmpl(const std::string &identifier,
auto &dxil_codes = module_data.dxil_codes[tmpl_identifier];
auto &compiled_kernel = module_data.kernels[tmpl_identifier];

KernelCodeGenDX12 cgen(kernel);
KernelCodeGenDX12 cgen(&config_, kernel);
auto compiled_data = cgen.compile();
for (auto &dxil : compiled_data.task_dxil_source_codes) {
dxil_codes.emplace_back(dxil);
Expand Down
4 changes: 3 additions & 1 deletion taichi/runtime/dx12/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ struct ModuleDataDX12 : public aot::ModuleData {

class AotModuleBuilderImpl : public AotModuleBuilder {
public:
explicit AotModuleBuilderImpl(LlvmProgramImpl *prog);
explicit AotModuleBuilderImpl(const CompileConfig &config,
LlvmProgramImpl *prog);

void dump(const std::string &output_dir,
const std::string &filename) const override;
Expand All @@ -34,6 +35,7 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
const std::string &key,
Kernel *kernel) override;

const CompileConfig &config_;
LlvmProgramImpl *prog;
ModuleDataDX12 module_data;
};
Expand Down
9 changes: 8 additions & 1 deletion taichi/runtime/llvm/llvm_aot_module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ namespace taichi::lang {

class LlvmAotModuleBuilder : public AotModuleBuilder {
public:
explicit LlvmAotModuleBuilder(LlvmProgramImpl *prog) : prog_(prog) {
explicit LlvmAotModuleBuilder(const CompileConfig *compile_config,
LlvmProgramImpl *prog)
: compile_config_(compile_config), prog_(prog) {
}

void dump(const std::string &output_dir,
Expand All @@ -30,8 +32,13 @@ class LlvmAotModuleBuilder : public AotModuleBuilder {
return cache_;
}

const CompileConfig *get_compile_config() const {
return compile_config_;
}

private:
mutable LlvmOfflineCache cache_;
const CompileConfig *compile_config_{nullptr};
LlvmProgramImpl *prog_ = nullptr;
};

Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/program_impls/dx12/dx12_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Dx12ProgramImpl::Dx12ProgramImpl(CompileConfig &config)

std::unique_ptr<AotModuleBuilder> Dx12ProgramImpl::make_aot_module_builder(
const DeviceCapabilityConfig &caps) {
return std::make_unique<directx12::AotModuleBuilderImpl>(this);
return std::make_unique<directx12::AotModuleBuilderImpl>(*config, this);
}

} // namespace lang
Expand Down
8 changes: 4 additions & 4 deletions taichi/runtime/program_impls/llvm/llvm_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ LlvmProgramImpl::LlvmProgramImpl(CompileConfig &config_,
}

FunctionType LlvmProgramImpl::compile(Kernel *kernel) {
auto codegen = KernelCodeGen::create(config->arch, kernel);
auto codegen = KernelCodeGen::create(config, kernel);
return codegen->compile_to_function();
}

Expand Down Expand Up @@ -89,18 +89,18 @@ void LlvmProgramImpl::materialize_snode_tree(SNodeTree *tree,
std::unique_ptr<AotModuleBuilder> LlvmProgramImpl::make_aot_module_builder(
const DeviceCapabilityConfig &caps) {
if (config->arch == Arch::x64 || config->arch == Arch::arm64) {
return std::make_unique<cpu::AotModuleBuilderImpl>(this);
return std::make_unique<cpu::AotModuleBuilderImpl>(config, this);
}

#if defined(TI_WITH_CUDA)
if (config->arch == Arch::cuda) {
return std::make_unique<cuda::AotModuleBuilderImpl>(this);
return std::make_unique<cuda::AotModuleBuilderImpl>(config, this);
}
#endif

#if defined(TI_WITH_DX12)
if (config->arch == Arch::dx12) {
return std::make_unique<directx12::AotModuleBuilderImpl>(this);
return std::make_unique<directx12::AotModuleBuilderImpl>(*config, this);
}
#endif

Expand Down
6 changes: 4 additions & 2 deletions taichi/runtime/wasm/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
namespace taichi::lang {
namespace wasm {

AotModuleBuilderImpl::AotModuleBuilderImpl() : module_(nullptr) {
AotModuleBuilderImpl::AotModuleBuilderImpl(const CompileConfig *compile_config)
: compile_config_(compile_config), module_(nullptr) {
TI_AUTO_PROF
}

Expand All @@ -35,7 +36,8 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir,

void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
Kernel *kernel) {
auto module_info = KernelCodeGenWASM(kernel).compile_kernel_to_module();
auto module_info =
KernelCodeGenWASM(compile_config_, kernel).compile_kernel_to_module();
if (module_) {
llvm::Linker::linkModules(*module_, std::move(module_info.module),
llvm::Linker::OverrideFromSrc);
Expand Down
3 changes: 2 additions & 1 deletion taichi/runtime/wasm/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace wasm {

class AotModuleBuilderImpl : public AotModuleBuilder {
public:
explicit AotModuleBuilderImpl();
explicit AotModuleBuilderImpl(const CompileConfig *compile_config);

void dump(const std::string &output_dir,
const std::string &filename) const override;
Expand All @@ -34,6 +34,7 @@ class AotModuleBuilderImpl : public AotModuleBuilder {

private:
void eliminate_unused_functions() const;
const CompileConfig *compile_config_{nullptr};
std::unique_ptr<llvm::Module> module_{nullptr};
std::vector<std::string> name_list_;
};
Expand Down

0 comments on commit e3d8f73

Please sign in to comment.