Skip to content

Commit

Permalink
Refactor2023:Use CKD+KC+KCM to re-impl gfx::AotModuleBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
PGZXB committed Dec 18, 2022
1 parent a77e46a commit 0697032
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 19 deletions.
2 changes: 1 addition & 1 deletion taichi/cache/gfx/cache_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ CacheManager::CacheManager(Params &&init_params)
}

caching_module_builder_ = std::make_unique<gfx::AotModuleBuilderImpl>(
compiled_structs_, init_params.arch, compile_config_,
compiled_structs_, nullptr, init_params.arch, compile_config_,
std::move(init_params.caps));

offline_cache_metadata_.version[0] = TI_VERSION_MAJOR;
Expand Down
35 changes: 26 additions & 9 deletions taichi/runtime/gfx/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "taichi/program/kernel.h"
#include "taichi/aot/module_data.h"
#include "taichi/codegen/spirv/spirv_codegen.h"
#include "taichi/codegen/spirv/cxompiled_kernel_data.h"
#include "taichi/runtime/gfx/aot_graph_data.h"

namespace taichi::lang {
Expand Down Expand Up @@ -104,10 +105,12 @@ class AotDataConverter {
} // namespace
AotModuleBuilderImpl::AotModuleBuilderImpl(
const std::vector<CompiledSNodeStructs> &compiled_structs,
KernelCompilationManager *kernel_compilation_manager,
Arch device_api_backend,
const CompileConfig &compile_config,
const DeviceCapabilityConfig &caps)
: compiled_structs_(compiled_structs),
kernel_compilation_manager_(kernel_compilation_manager),
device_api_backend_(device_api_backend),
config_(compile_config),
caps_(caps) {
Expand Down Expand Up @@ -200,10 +203,7 @@ AotModuleBuilderImpl::try_get_kernel_register_params(

void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
Kernel *kernel) {
spirv::lower(config_, kernel);
auto compiled =
run_codegen(kernel, this->device_api_backend_, caps_, compiled_structs_,
config_.external_optimization_level > 0);
auto compiled = compile_kernel(*kernel);
compiled.kernel_attribs.name = identifier;
ti_aot_data_.kernels.push_back(compiled.kernel_attribs);
ti_aot_data_.spirv_codes.push_back(compiled.task_spirv_source_codes);
Expand Down Expand Up @@ -243,15 +243,32 @@ void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
void AotModuleBuilderImpl::add_per_backend_tmpl(const std::string &identifier,
const std::string &key,
Kernel *kernel) {
spirv::lower(config_, kernel);
auto compiled =
run_codegen(kernel, device_api_backend_, caps_, compiled_structs_,
config_.external_optimization_level > 0);

auto compiled = compile_kernel(*kernel);
compiled.kernel_attribs.name = identifier + "|" + key;
ti_aot_data_.kernels.push_back(compiled.kernel_attribs);
ti_aot_data_.spirv_codes.push_back(compiled.task_spirv_source_codes);
}

GfxRuntime::RegisterParams AotModuleBuilderImpl::compile_kernel(
const Kernel &kernel_def) {
if (kernel_compilation_manager_) {
const auto &compiled = kernel_compilation_manager_->load_or_compile(
config_, caps_, kernel_def);
const auto &spirv_compiled =
dynamic_cast<const spirv::CompiledKernelData &>(compiled);
const auto &spirv_data = spirv_compiled.get_internal_data();
gfx::GfxRuntime::RegisterParams params;
params.kernel_attribs = spirv_data.kernel_attribs;
params.task_spirv_source_codes = spirv_data.spirv_src;
params.num_snode_trees = spirv_data.num_snode_trees;
return params;
} else { // Refactor2023:FIXME: Remove the branch
auto *kernel = const_cast<Kernel *>(&kernel_def);
spirv::lower(config_, kernel);
return run_codegen(kernel, device_api_backend_, caps_, compiled_structs_,
config_.external_optimization_level > 0);
}
}

} // namespace gfx
} // namespace taichi::lang
5 changes: 5 additions & 0 deletions taichi/runtime/gfx/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "taichi/aot/module_builder.h"
#include "taichi/runtime/gfx/aot_utils.h"
#include "taichi/runtime/gfx/runtime.h"
#include "taichi/cache/kernel_compilation_manager.h"
#include "taichi/codegen/spirv/snode_struct_compiler.h"
#include "taichi/codegen/spirv/kernel_utils.h"

Expand All @@ -16,6 +17,7 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
public:
explicit AotModuleBuilderImpl(
const std::vector<CompiledSNodeStructs> &compiled_structs,
KernelCompilationManager *kernel_compilation_manager,
Arch device_api_backend,
const CompileConfig &compile_config,
const DeviceCapabilityConfig &caps);
Expand Down Expand Up @@ -47,7 +49,10 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
const TaskAttributes &k,
const std::vector<uint32_t> &source_code) const;

GfxRuntime::RegisterParams compile_kernel(const Kernel &kernel_def);

const std::vector<CompiledSNodeStructs> &compiled_structs_;
KernelCompilationManager *kernel_compilation_manager_{nullptr};
TaichiAotData ti_aot_data_;

Arch device_api_backend_;
Expand Down
5 changes: 3 additions & 2 deletions taichi/runtime/program_impls/dx/dx_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ std::unique_ptr<AotModuleBuilder> Dx11ProgramImpl::make_aot_module_builder(
const DeviceCapabilityConfig &caps) {
if (runtime_) {
return std::make_unique<gfx::AotModuleBuilderImpl>(
snode_tree_mgr_->get_compiled_structs(), Arch::dx11, *config, caps);
snode_tree_mgr_->get_compiled_structs(), nullptr, Arch::dx11, *config,
caps);
} else {
return std::make_unique<gfx::AotModuleBuilderImpl>(
aot_compiled_snode_structs_, Arch::dx11, *config, caps);
aot_compiled_snode_structs_, nullptr, Arch::dx11, *config, caps);
}
}

Expand Down
5 changes: 3 additions & 2 deletions taichi/runtime/program_impls/opengl/opengl_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ std::unique_ptr<AotModuleBuilder> OpenglProgramImpl::make_aot_module_builder(
const DeviceCapabilityConfig &caps) {
if (runtime_) {
return std::make_unique<gfx::AotModuleBuilderImpl>(
snode_tree_mgr_->get_compiled_structs(), Arch::opengl, *config, caps);
snode_tree_mgr_->get_compiled_structs(), nullptr, Arch::opengl, *config,
caps);
} else {
return std::make_unique<gfx::AotModuleBuilderImpl>(
aot_compiled_snode_structs_, Arch::opengl, *config, caps);
aot_compiled_snode_structs_, nullptr, Arch::opengl, *config, caps);
}
}

Expand Down
14 changes: 9 additions & 5 deletions taichi/runtime/program_impls/vulkan/vulkan_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,12 @@ std::unique_ptr<AotModuleBuilder> VulkanProgramImpl::make_aot_module_builder(
const DeviceCapabilityConfig &caps) {
if (vulkan_runtime_) {
return std::make_unique<gfx::AotModuleBuilderImpl>(
snode_tree_mgr_->get_compiled_structs(), Arch::vulkan, *config, caps);
snode_tree_mgr_->get_compiled_structs(),
get_kernel_compilation_maanger().get(), Arch::vulkan, *config, caps);
} else {
return std::make_unique<gfx::AotModuleBuilderImpl>(
aot_compiled_snode_structs_, Arch::vulkan, *config, caps);
aot_compiled_snode_structs_, get_kernel_compilation_maanger().get(),
Arch::vulkan, *config, caps);
}
}

Expand Down Expand Up @@ -240,10 +242,12 @@ const std::unique_ptr<KernelCompilationManager>
KernelCompilationManager::Config init_config;
// FIXME: Rm CompileConfig::offline_cache_file_path &
// Mv it to TaichiConfig
const auto *structs = vulkan_runtime_
? &snode_tree_mgr_->get_compiled_structs()
: &aot_compiled_snode_structs_;
init_config.offline_cache_path = config->offline_cache_file_path;
init_config.kernel_compiler =
std::make_unique<spirv::KernelCompiler>(spirv::KernelCompiler::Config{
&snode_tree_mgr_->get_compiled_structs()});
init_config.kernel_compiler = std::make_unique<spirv::KernelCompiler>(
spirv::KernelCompiler::Config{structs});
kernel_compilation_mgr_ =
std::make_unique<KernelCompilationManager>(std::move(init_config));
}
Expand Down

0 comments on commit 0697032

Please sign in to comment.