Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Remove dependencies on Program::this_thread_config() in KernelCodeGen #7086

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion taichi/codegen/amdgpu/codegen_amdgpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ namespace lang {

class KernelCodeGenAMDGPU : public KernelCodeGen {
public:
KernelCodeGenAMDGPU(Kernel *kernel) : KernelCodeGen(kernel) {
KernelCodeGenAMDGPU(const CompileConfig *config, Kernel *kernel)
: KernelCodeGen(config, kernel) {
}

// TODO: Stop defining this macro guards in the headers
Expand Down
23 changes: 13 additions & 10 deletions taichi/codegen/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,30 @@

namespace taichi::lang {

KernelCodeGen::KernelCodeGen(Kernel *kernel)
: prog(kernel->program), kernel(kernel) {
KernelCodeGen::KernelCodeGen(const CompileConfig *compile_config,
Kernel *kernel)
: prog(kernel->program), kernel(kernel), compile_config_(compile_config) {
this->ir = kernel->ir.get();
}

std::unique_ptr<KernelCodeGen> KernelCodeGen::create(Arch arch,
Kernel *kernel) {
std::unique_ptr<KernelCodeGen> KernelCodeGen::create(
const CompileConfig *compile_config,
Kernel *kernel) {
#ifdef TI_WITH_LLVM
const auto arch = compile_config->arch;
if (arch_is_cpu(arch) && arch != Arch::wasm) {
return std::make_unique<KernelCodeGenCPU>(kernel);
return std::make_unique<KernelCodeGenCPU>(compile_config, kernel);
} else if (arch == Arch::wasm) {
return std::make_unique<KernelCodeGenWASM>(kernel);
return std::make_unique<KernelCodeGenWASM>(compile_config, kernel);
} else if (arch == Arch::cuda) {
#if defined(TI_WITH_CUDA)
return std::make_unique<KernelCodeGenCUDA>(kernel);
return std::make_unique<KernelCodeGenCUDA>(compile_config, kernel);
#else
TI_NOT_IMPLEMENTED
#endif
} else if (arch == Arch::dx12) {
#if defined(TI_WITH_DX12)
return std::make_unique<KernelCodeGenDX12>(kernel);
return std::make_unique<KernelCodeGenDX12>(compile_config, kernel);
#else
TI_NOT_IMPLEMENTED
#endif
Expand All @@ -58,7 +61,7 @@ std::optional<LLVMCompiledKernel>
KernelCodeGen::maybe_read_compilation_from_cache(
const std::string &kernel_key) {
TI_AUTO_PROF;
const auto &config = prog->this_thread_config();
const auto &config = *compile_config_;
auto *llvm_prog = get_llvm_program(prog);
const auto &reader = llvm_prog->get_cache_reader();
if (!reader) {
Expand All @@ -83,8 +86,8 @@ void KernelCodeGen::cache_kernel(const std::string &kernel_key,
}

LLVMCompiledKernel KernelCodeGen::compile_kernel_to_module() {
const auto &config = *compile_config_;
auto *llvm_prog = get_llvm_program(prog);
const auto &config = prog->this_thread_config();
auto *tlctx = llvm_prog->get_llvm_context(config.arch);
std::string kernel_key = get_hashed_offline_cache_key(&config, kernel);
kernel->set_kernel_key_for_cache(kernel_key);
Expand Down
8 changes: 6 additions & 2 deletions taichi/codegen/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ class KernelCodeGen {
IRNode *ir;

public:
explicit KernelCodeGen(Kernel *kernel);
explicit KernelCodeGen(const CompileConfig *compile_config, Kernel *kernel);

virtual ~KernelCodeGen() = default;

static std::unique_ptr<KernelCodeGen> create(Arch arch, Kernel *kernel);
static std::unique_ptr<KernelCodeGen> create(
const CompileConfig *compile_config,
Kernel *kernel);

virtual FunctionType compile_to_function() = 0;
virtual bool supports_offline_cache() const {
Expand All @@ -65,6 +67,8 @@ class KernelCodeGen {
void cache_kernel(const std::string &kernel_key,
const LLVMCompiledKernel &data);
#endif
private:
const CompileConfig *compile_config_{nullptr};
};

#ifdef TI_WITH_LLVM
Expand Down
3 changes: 2 additions & 1 deletion taichi/codegen/cpu/codegen_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ namespace taichi::lang {

class KernelCodeGenCPU : public KernelCodeGen {
public:
explicit KernelCodeGenCPU(Kernel *kernel) : KernelCodeGen(kernel) {
explicit KernelCodeGenCPU(const CompileConfig *compile_config, Kernel *kernel)
: KernelCodeGen(compile_config, kernel) {
}

// TODO: Stop defining this macro guards in the headers
Expand Down
4 changes: 3 additions & 1 deletion taichi/codegen/cuda/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ namespace taichi::lang {

class KernelCodeGenCUDA : public KernelCodeGen {
public:
explicit KernelCodeGenCUDA(Kernel *kernel) : KernelCodeGen(kernel) {
explicit KernelCodeGenCUDA(const CompileConfig *compile_config,
Kernel *kernel)
: KernelCodeGen(compile_config, kernel) {
}

// TODO: Stop defining this macro guards in the headers
Expand Down
4 changes: 3 additions & 1 deletion taichi/codegen/dx12/codegen_dx12.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ namespace taichi::lang {

class KernelCodeGenDX12 : public KernelCodeGen {
public:
explicit KernelCodeGenDX12(Kernel *kernel) : KernelCodeGen(kernel) {
explicit KernelCodeGenDX12(const CompileConfig *compile_config,
Kernel *kernel)
: KernelCodeGen(compile_config, kernel) {
}
struct CompileResult {
std::vector<std::vector<uint8_t>> task_dxil_source_codes;
Expand Down
4 changes: 3 additions & 1 deletion taichi/codegen/wasm/codegen_wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ namespace taichi::lang {

class KernelCodeGenWASM : public KernelCodeGen {
public:
explicit KernelCodeGenWASM(Kernel *kernel) : KernelCodeGen(kernel) {
explicit KernelCodeGenWASM(const CompileConfig *compile_config,
Kernel *kernel)
: KernelCodeGen(compile_config, kernel) {
}

FunctionType compile_to_function() override;
Expand Down
2 changes: 1 addition & 1 deletion taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ std::unique_ptr<AotModuleBuilder> Program::make_aot_module_builder(
if (arch == Arch::wasm) {
// Have to check WASM first, or it dispatches to the LlvmProgramImpl.
#ifdef TI_WITH_LLVM
return std::make_unique<wasm::AotModuleBuilderImpl>();
return std::make_unique<wasm::AotModuleBuilderImpl>(&this_thread_config());
#else
TI_NOT_IMPLEMENTED
#endif
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cpu/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace taichi::lang {
namespace cpu {

LLVMCompiledKernel AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
auto cgen = KernelCodeGenCPU(kernel);
auto cgen = KernelCodeGenCPU(get_compile_config(), kernel);
return cgen.compile_kernel_to_module();
}

Expand Down
5 changes: 3 additions & 2 deletions taichi/runtime/cpu/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ namespace cpu {

class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
public:
explicit AotModuleBuilderImpl(LlvmProgramImpl *prog)
: LlvmAotModuleBuilder(prog) {
explicit AotModuleBuilderImpl(const CompileConfig *compile_config,
LlvmProgramImpl *prog)
: LlvmAotModuleBuilder(compile_config, prog) {
}

private:
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cuda/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace taichi::lang {
namespace cuda {

LLVMCompiledKernel AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
auto cgen = KernelCodeGenCUDA(kernel);
auto cgen = KernelCodeGenCUDA(get_compile_config(), kernel);
return cgen.compile_kernel_to_module();
}

Expand Down
5 changes: 3 additions & 2 deletions taichi/runtime/cuda/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ namespace cuda {

class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
public:
explicit AotModuleBuilderImpl(LlvmProgramImpl *prog)
: LlvmAotModuleBuilder(prog) {
explicit AotModuleBuilderImpl(const CompileConfig *compile_config,
LlvmProgramImpl *prog)
: LlvmAotModuleBuilder(compile_config, prog) {
}

private:
Expand Down
8 changes: 5 additions & 3 deletions taichi/runtime/dx12/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
namespace taichi::lang {
namespace directx12 {

AotModuleBuilderImpl::AotModuleBuilderImpl(LlvmProgramImpl *prog) : prog(prog) {
AotModuleBuilderImpl::AotModuleBuilderImpl(const CompileConfig &config,
LlvmProgramImpl *prog)
: config_(config), prog(prog) {
// FIXME: set correct root buffer size.
module_data.root_buffer_size = 1;
}
Expand All @@ -19,7 +21,7 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
auto &dxil_codes = module_data.dxil_codes[identifier];
auto &compiled_kernel = module_data.kernels[identifier];

KernelCodeGenDX12 cgen(kernel);
KernelCodeGenDX12 cgen(&config_, kernel);
auto compiled_data = cgen.compile();
for (auto &dxil : compiled_data.task_dxil_source_codes) {
dxil_codes.emplace_back(dxil);
Expand Down Expand Up @@ -69,7 +71,7 @@ void AotModuleBuilderImpl::add_per_backend_tmpl(const std::string &identifier,
auto &dxil_codes = module_data.dxil_codes[tmpl_identifier];
auto &compiled_kernel = module_data.kernels[tmpl_identifier];

KernelCodeGenDX12 cgen(kernel);
KernelCodeGenDX12 cgen(&config_, kernel);
auto compiled_data = cgen.compile();
for (auto &dxil : compiled_data.task_dxil_source_codes) {
dxil_codes.emplace_back(dxil);
Expand Down
4 changes: 3 additions & 1 deletion taichi/runtime/dx12/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ struct ModuleDataDX12 : public aot::ModuleData {

class AotModuleBuilderImpl : public AotModuleBuilder {
public:
explicit AotModuleBuilderImpl(LlvmProgramImpl *prog);
explicit AotModuleBuilderImpl(const CompileConfig &config,
LlvmProgramImpl *prog);

void dump(const std::string &output_dir,
const std::string &filename) const override;
Expand All @@ -34,6 +35,7 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
const std::string &key,
Kernel *kernel) override;

const CompileConfig &config_;
LlvmProgramImpl *prog;
ModuleDataDX12 module_data;
};
Expand Down
9 changes: 8 additions & 1 deletion taichi/runtime/llvm/llvm_aot_module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ namespace taichi::lang {

class LlvmAotModuleBuilder : public AotModuleBuilder {
public:
explicit LlvmAotModuleBuilder(LlvmProgramImpl *prog) : prog_(prog) {
explicit LlvmAotModuleBuilder(const CompileConfig *compile_config,
LlvmProgramImpl *prog)
: compile_config_(compile_config), prog_(prog) {
}

void dump(const std::string &output_dir,
Expand All @@ -30,8 +32,13 @@ class LlvmAotModuleBuilder : public AotModuleBuilder {
return cache_;
}

const CompileConfig *get_compile_config() const {
return compile_config_;
}

private:
mutable LlvmOfflineCache cache_;
const CompileConfig *compile_config_{nullptr};
LlvmProgramImpl *prog_ = nullptr;
};

Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/program_impls/dx12/dx12_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Dx12ProgramImpl::Dx12ProgramImpl(CompileConfig &config)

std::unique_ptr<AotModuleBuilder> Dx12ProgramImpl::make_aot_module_builder(
const DeviceCapabilityConfig &caps) {
return std::make_unique<directx12::AotModuleBuilderImpl>(this);
return std::make_unique<directx12::AotModuleBuilderImpl>(*config, this);
}

} // namespace lang
Expand Down
8 changes: 4 additions & 4 deletions taichi/runtime/program_impls/llvm/llvm_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ LlvmProgramImpl::LlvmProgramImpl(CompileConfig &config_,
}

FunctionType LlvmProgramImpl::compile(Kernel *kernel) {
auto codegen = KernelCodeGen::create(config->arch, kernel);
auto codegen = KernelCodeGen::create(config, kernel);
return codegen->compile_to_function();
}

Expand Down Expand Up @@ -89,18 +89,18 @@ void LlvmProgramImpl::materialize_snode_tree(SNodeTree *tree,
std::unique_ptr<AotModuleBuilder> LlvmProgramImpl::make_aot_module_builder(
const DeviceCapabilityConfig &caps) {
if (config->arch == Arch::x64 || config->arch == Arch::arm64) {
return std::make_unique<cpu::AotModuleBuilderImpl>(this);
return std::make_unique<cpu::AotModuleBuilderImpl>(config, this);
}

#if defined(TI_WITH_CUDA)
if (config->arch == Arch::cuda) {
return std::make_unique<cuda::AotModuleBuilderImpl>(this);
return std::make_unique<cuda::AotModuleBuilderImpl>(config, this);
}
#endif

#if defined(TI_WITH_DX12)
if (config->arch == Arch::dx12) {
return std::make_unique<directx12::AotModuleBuilderImpl>(this);
return std::make_unique<directx12::AotModuleBuilderImpl>(*config, this);
}
#endif

Expand Down
6 changes: 4 additions & 2 deletions taichi/runtime/wasm/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
namespace taichi::lang {
namespace wasm {

AotModuleBuilderImpl::AotModuleBuilderImpl() : module_(nullptr) {
AotModuleBuilderImpl::AotModuleBuilderImpl(const CompileConfig *compile_config)
: compile_config_(compile_config), module_(nullptr) {
TI_AUTO_PROF
}

Expand All @@ -35,7 +36,8 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir,

void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
Kernel *kernel) {
auto module_info = KernelCodeGenWASM(kernel).compile_kernel_to_module();
auto module_info =
KernelCodeGenWASM(compile_config_, kernel).compile_kernel_to_module();
if (module_) {
llvm::Linker::linkModules(*module_, std::move(module_info.module),
llvm::Linker::OverrideFromSrc);
Expand Down
3 changes: 2 additions & 1 deletion taichi/runtime/wasm/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace wasm {

class AotModuleBuilderImpl : public AotModuleBuilder {
public:
explicit AotModuleBuilderImpl();
explicit AotModuleBuilderImpl(const CompileConfig *compile_config);

void dump(const std::string &output_dir,
const std::string &filename) const override;
Expand All @@ -34,6 +34,7 @@ class AotModuleBuilderImpl : public AotModuleBuilder {

private:
void eliminate_unused_functions() const;
const CompileConfig *compile_config_{nullptr};
std::unique_ptr<llvm::Module> module_{nullptr};
std::vector<std::string> name_list_;
};
Expand Down