[refactor] Remove unnecessary Kernel::arch (#7074)
Issue: #7002
PGZXB committed Jan 9, 2023
1 parent 539d2a5 commit 04150ac
Showing 11 changed files with 29 additions and 56 deletions.
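
In short: each Kernel used to cache its own Arch, copied from the owning Program's per-thread CompileConfig at construction time and overridable through set_arch(). The copy was redundant, so this commit deletes the field and has every consumer read the arch from the config directly. A minimal before/after sketch of the pattern (simplified; prog is the kernel's owning Program and get_llvm_context() is the existing per-arch LLVM context lookup):

// Before: each Kernel carried a redundant copy of the arch.
auto *tlctx = llvm_prog->get_llvm_context(kernel->arch);

// After: the arch comes from the per-thread config, typically hoisted
// into a local reference once per function.
const auto &config = prog->this_thread_config();
auto *tlctx = llvm_prog->get_llvm_context(config.arch);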
4 changes: 2 additions & 2 deletions taichi/codegen/codegen.cpp
@@ -84,8 +84,8 @@ void KernelCodeGen::cache_kernel(const std::string &kernel_key,
 
 LLVMCompiledKernel KernelCodeGen::compile_kernel_to_module() {
   auto *llvm_prog = get_llvm_program(prog);
-  auto *tlctx = llvm_prog->get_llvm_context(kernel->arch);
-  auto &config = prog->this_thread_config();
+  const auto &config = prog->this_thread_config();
+  auto *tlctx = llvm_prog->get_llvm_context(config.arch);
   std::string kernel_key = get_hashed_offline_cache_key(&config, kernel);
   kernel->set_kernel_key_for_cache(kernel_key);
   if (config.offline_cache && this->supports_offline_cache() &&
3 changes: 2 additions & 1 deletion taichi/codegen/cpu/codegen_cpu.cpp
@@ -274,7 +274,8 @@ LLVMCompiledTask KernelCodeGenCPU::compile_task(
 FunctionType KernelCodeGenCPU::compile_to_function() {
   TI_AUTO_PROF;
   auto *llvm_prog = get_llvm_program(prog);
-  auto *tlctx = llvm_prog->get_llvm_context(kernel->arch);
+  const auto &config = prog->this_thread_config();
+  auto *tlctx = llvm_prog->get_llvm_context(config.arch);
 
   CPUModuleToFunctionConverter converter(
       tlctx, get_llvm_program(prog)->get_runtime_executor());
3 changes: 2 additions & 1 deletion taichi/codegen/cuda/codegen_cuda.cpp
@@ -603,7 +603,8 @@ LLVMCompiledTask KernelCodeGenCUDA::compile_task(
 FunctionType KernelCodeGenCUDA::compile_to_function() {
   TI_AUTO_PROF
   auto *llvm_prog = get_llvm_program(prog);
-  auto *tlctx = llvm_prog->get_llvm_context(kernel->arch);
+  const auto &config = prog->this_thread_config();
+  auto *tlctx = llvm_prog->get_llvm_context(config.arch);
 
   CUDAModuleToFunctionConverter converter{tlctx,
                                           llvm_prog->get_runtime_executor()};
21 changes: 13 additions & 8 deletions taichi/codegen/llvm/codegen_llvm.cpp
@@ -306,11 +306,14 @@ TaskCodeGenLLVM::TaskCodeGenLLVM(Kernel *kernel,
                                  std::unique_ptr<llvm::Module> &&module)
     // TODO: simplify LLVMModuleBuilder ctor input
     : LLVMModuleBuilder(
-          module == nullptr ? get_llvm_program(kernel->program)
-                                  ->get_llvm_context(kernel->arch)
-                                  ->new_module("kernel")
-                            : std::move(module),
-          get_llvm_program(kernel->program)->get_llvm_context(kernel->arch)),
+          module == nullptr
+              ? get_llvm_program(kernel->program)
+                    ->get_llvm_context(
+                        kernel->program->this_thread_config().arch)
+                    ->new_module("kernel")
+              : std::move(module),
+          get_llvm_program(kernel->program)
+              ->get_llvm_context(kernel->program->this_thread_config().arch)),
       kernel(kernel),
       ir(ir),
       prog(kernel->program) {
@@ -884,8 +887,9 @@ void TaskCodeGenLLVM::visit(IfStmt *if_stmt) {
 llvm::Value *TaskCodeGenLLVM::create_print(std::string tag,
                                            DataType dt,
                                            llvm::Value *value) {
-  if (!arch_is_cpu(kernel->arch)) {
-    TI_WARN("print not supported on arch {}", arch_name(kernel->arch));
+  if (!arch_is_cpu(prog->this_thread_config().arch)) {
+    TI_WARN("print not supported on arch {}",
+            arch_name(prog->this_thread_config().arch));
     return nullptr;
   }
   std::vector<llvm::Value *> args;
@@ -2566,7 +2570,8 @@ FunctionCreationGuard TaskCodeGenLLVM::get_function_creation_guard(
 }
 
 void TaskCodeGenLLVM::initialize_context() {
-  tlctx = get_llvm_program(prog)->get_llvm_context(kernel->arch);
+  tlctx =
+      get_llvm_program(prog)->get_llvm_context(prog->this_thread_config().arch);
   llvm_context = tlctx->get_this_thread_context();
   builder = std::make_unique<llvm::IRBuilder<>>(*llvm_context);
 }
2 changes: 1 addition & 1 deletion taichi/codegen/llvm/codegen_llvm.h
@@ -74,7 +74,7 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
               std::unique_ptr<llvm::Module> &&module = nullptr);
 
   Arch current_arch() {
-    return kernel->arch;
+    return prog->this_thread_config().arch;
   }
 
   void initialize_context();
6 changes: 4 additions & 2 deletions taichi/codegen/wasm/codegen_wasm.cpp
@@ -238,7 +238,8 @@ class TaskCodeGenWASM : public TaskCodeGenLLVM {
 FunctionType KernelCodeGenWASM::compile_to_function() {
   TI_AUTO_PROF
   auto linked = compile_kernel_to_module();
-  auto *tlctx = get_llvm_program(prog)->get_llvm_context(kernel->arch);
+  auto *tlctx =
+      get_llvm_program(prog)->get_llvm_context(prog->this_thread_config().arch);
   tlctx->create_jit_module(std::move(linked.module));
   auto kernel_symbol = tlctx->lookup_function_pointer(linked.tasks[0].name);
   return [=](RuntimeContext &context) {
@@ -275,7 +276,8 @@ LLVMCompiledTask KernelCodeGenWASM::compile_task(
 }
 
 LLVMCompiledKernel KernelCodeGenWASM::compile_kernel_to_module() {
-  auto *tlctx = get_llvm_program(prog)->get_llvm_context(kernel->arch);
+  auto *tlctx =
+      get_llvm_program(prog)->get_llvm_context(prog->this_thread_config().arch);
   if (!kernel->lowered()) {
     kernel->lower(/*to_executable=*/false);
   }
17 changes: 4 additions & 13 deletions taichi/program/kernel.cpp
@@ -46,8 +46,6 @@ Kernel::Kernel(Program &program,
   compiled_ = nullptr;
   ir_is_ast_ = false;  // CHI IR
 
-  arch = program.this_thread_config().arch;
-
   if (autodiff_mode == AutodiffMode::kNone) {
     name = primal_name;
   } else if (autodiff_mode == AutodiffMode::kForward) {
@@ -64,10 +62,10 @@ void Kernel::compile() {
 
 void Kernel::lower(bool to_executable) {
   TI_ASSERT(!lowered_);
-  TI_ASSERT(supports_lowering(arch));
+  TI_ASSERT(supports_lowering(program->this_thread_config().arch));
 
   CurrentCallableGuard _(program, this);
-  auto config = program->this_thread_config();
+  const auto &config = program->this_thread_config();
   bool verbose = config.print_ir;
   if ((is_accessor && !config.print_accessor_ir) ||
       (is_evaluator && !config.print_evaluator_ir))
@@ -109,8 +107,8 @@ void Kernel::operator()(LaunchContextBuilder &ctx_builder) {
 
   compiled_(ctx_builder.get_context());
 
-  program->sync = (program->sync && arch_is_cpu(arch));
-  // Note that Kernel::arch may be different from program.config.arch
+  program->sync =
+      (program->sync && arch_is_cpu(program->this_thread_config().arch));
   if (program->this_thread_config().debug &&
       (arch_is_cpu(program->this_thread_config().arch) ||
        program->this_thread_config().arch == Arch::cuda)) {
@@ -347,11 +345,6 @@ std::vector<float64> Kernel::get_ret_float_tensor(int i) {
   return res;
 }
 
-void Kernel::set_arch(Arch arch) {
-  TI_ASSERT(!compiled_);
-  this->arch = arch;
-}
-
 std::string Kernel::get_name() const {
   return name;
 }
@@ -372,8 +365,6 @@ void Kernel::init(Program &program,
   ir = context->get_root();
   ir_is_ast_ = true;
 
-  this->arch = program.this_thread_config().arch;
-
   if (autodiff_mode == AutodiffMode::kNone) {
     name = primal_name;
   } else if (autodiff_mode == AutodiffMode::kCheckAutodiffValid) {
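
The comment deleted from Kernel::operator() ("Note that Kernel::arch may be different from program.config.arch") named exactly the divergence this refactor eliminates: with the field gone, the per-thread config is the single source of truth for the arch.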
3 changes: 0 additions & 3 deletions taichi/program/kernel.h
@@ -17,7 +17,6 @@ class TI_DLL_EXPORT Kernel : public Callable {
  public:
   std::string name;
   std::vector<SNode *> no_activate;
-  Arch arch;
 
   bool is_accessor{false};
   bool is_evaluator{false};
@@ -117,8 +116,6 @@
 
   std::vector<float64> get_ret_float_tensor(int i);
 
-  void set_arch(Arch arch);
-
   uint64 get_next_task_id() {
     return task_counter_++;
   }
22 changes: 0 additions & 22 deletions taichi/program/program.cpp
@@ -336,26 +336,6 @@ void Program::visualize_layout(const std::string &fn) {
   trash(system(fmt::format("pdflatex {}", fn).c_str()));
 }
 
-Arch Program::get_accessor_arch() {
-  if (this_thread_config().arch == Arch::opengl) {
-    return Arch::opengl;
-  } else if (this_thread_config().arch == Arch::vulkan) {
-    return Arch::vulkan;
-  } else if (this_thread_config().arch == Arch::cuda) {
-    return Arch::cuda;
-  } else if (this_thread_config().arch == Arch::metal) {
-    return Arch::metal;
-  } else if (this_thread_config().arch == Arch::cc) {
-    return Arch::cc;
-  } else if (this_thread_config().arch == Arch::dx11) {
-    return Arch::dx11;
-  } else if (this_thread_config().arch == Arch::dx12) {
-    return Arch::dx12;
-  } else {
-    return get_host_arch();
-  }
-}
-
 Kernel &Program::get_snode_reader(SNode *snode) {
   TI_ASSERT(snode->type == SNodeType::place);
   auto kernel_name = fmt::format("snode_reader_{}", snode->id);
@@ -371,7 +351,6 @@ Kernel &Program::get_snode_reader(SNode *snode) {
         builder.expr_subscript(Expr(snode_to_fields_.at(snode)), indices)));
     builder.insert(std::move(ret));
   });
-  ker.set_arch(get_accessor_arch());
   ker.name = kernel_name;
   ker.is_accessor = true;
   for (int i = 0; i < snode->num_active_indices; i++)
@@ -399,7 +378,6 @@ Kernel &Program::get_snode_writer(SNode *snode) {
                             snode->dt->get_compute_type()),
                         expr->tb);
   });
-  ker.set_arch(get_accessor_arch());
   ker.name = kernel_name;
   ker.is_accessor = true;
   for (int i = 0; i < snode->num_active_indices; i++)
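
The removed Program::get_accessor_arch() returned this_thread_config().arch for every backend it explicitly recognized and fell back to the host arch otherwise, so dropping the ker.set_arch(get_accessor_arch()) calls means accessor kernels now follow the per-thread config arch like every other kernel.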
2 changes: 0 additions & 2 deletions taichi/program/program.h
@@ -217,8 +217,6 @@ class TI_DLL_EXPORT Program {
     return host_arch();
   }
 
-  Arch get_accessor_arch();
-
   float64 get_total_compilation_time() {
     return total_compilation_time_;
   }
2 changes: 1 addition & 1 deletion taichi/runtime/program_impls/llvm/llvm_program.cpp
@@ -36,7 +36,7 @@ LlvmProgramImpl::LlvmProgramImpl(CompileConfig &config_,
 }
 
 FunctionType LlvmProgramImpl::compile(Kernel *kernel) {
-  auto codegen = KernelCodeGen::create(kernel->arch, kernel);
+  auto codegen = KernelCodeGen::create(config->arch, kernel);
   return codegen->compile_to_function();
 }
 
