Skip to content

Commit

Permalink
[metal] Maintain a print string table per kernel (#6160)
Browse files Browse the repository at this point in the history
Issue: #4401 
* Fix a potential bug in metal AOT
* Prepare for implementing offline cache on metal

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PGZXB and pre-commit-ci[bot] committed Sep 26, 2022
1 parent 05e037d commit c4174e0
Show file tree
Hide file tree
Showing 9 changed files with 47 additions and 80 deletions.
21 changes: 8 additions & 13 deletions taichi/codegen/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,12 @@ class KernelCodegenImpl : public IRVisitor {
Kernel *kernel,
const CompiledRuntimeModule *compiled_runtime_module,
const std::vector<CompiledStructs> &compiled_snode_trees,
PrintStringTable *print_strtab,
const Config &config,
OffloadedStmt *offloaded)
: mtl_kernel_prefix_(taichi_kernel_name),
kernel_(kernel),
compiled_runtime_module_(compiled_runtime_module),
compiled_snode_trees_(compiled_snode_trees),
print_strtab_(print_strtab),
cgen_config_(config),
offloaded_(offloaded),
ctx_attribs_(*kernel_) {
Expand All @@ -216,11 +214,13 @@ class KernelCodegenImpl : public IRVisitor {
}

CompiledKernelData run() {
CompiledKernelData res;
print_strtab_ = &res.print_str_table;

emit_headers();
generate_structs();
generate_kernels();

CompiledKernelData res;
res.kernel_name = mtl_kernel_prefix_;
res.kernel_attribs = std::move(ti_kernel_attribs_);
res.ctx_attribs = std::move(ctx_attribs_);
Expand Down Expand Up @@ -1654,7 +1654,7 @@ class KernelCodegenImpl : public IRVisitor {
};
std::unordered_map<int, RootInfo> snode_to_roots_;
std::unordered_map<int, const GetRootStmt *> root_id_to_stmts_;
PrintStringTable *const print_strtab_;
PrintStringTable *print_strtab_{nullptr};
const Config &cgen_config_;
OffloadedStmt *const offloaded_;

Expand All @@ -1675,7 +1675,6 @@ CompiledKernelData run_codegen(
const CompiledRuntimeModule *compiled_runtime_module,
const std::vector<CompiledStructs> &compiled_snode_trees,
Kernel *kernel,
PrintStringTable *strtab,
OffloadedStmt *offloaded) {
const auto id = Program::get_kernel_id();
const auto taichi_kernel_name(
Expand All @@ -1685,8 +1684,7 @@ CompiledKernelData run_codegen(
cgen_config.allow_simdgroup = EnvConfig::instance().is_simdgroup_enabled();

KernelCodegenImpl codegen(taichi_kernel_name, kernel, compiled_runtime_module,
compiled_snode_trees, strtab, cgen_config,
offloaded);
compiled_snode_trees, cgen_config, offloaded);

return codegen.run();
}
Expand All @@ -1697,12 +1695,9 @@ FunctionType compile_to_metal_executable(
const CompiledRuntimeModule *compiled_runtime_module,
const std::vector<CompiledStructs> &compiled_snode_trees,
OffloadedStmt *offloaded) {
const auto compiled_res =
run_codegen(compiled_runtime_module, compiled_snode_trees, kernel,
kernel_mgr->print_strtable(), offloaded);
kernel_mgr->register_taichi_kernel(
compiled_res.kernel_name, compiled_res.source_code,
compiled_res.kernel_attribs, compiled_res.ctx_attribs, kernel);
const auto compiled_res = run_codegen(
compiled_runtime_module, compiled_snode_trees, kernel, offloaded);
kernel_mgr->register_taichi_kernel(compiled_res);
return [kernel_mgr,
kernel_name = compiled_res.kernel_name](RuntimeContext &ctx) {
kernel_mgr->launch_taichi_kernel(kernel_name, &ctx);
Expand Down
1 change: 0 additions & 1 deletion taichi/codegen/metal/codegen_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ CompiledKernelData run_codegen(
const CompiledRuntimeModule *compiled_runtime_module,
const std::vector<CompiledStructs> &compiled_snode_trees,
Kernel *kernel,
PrintStringTable *print_strtab,
OffloadedStmt *offloaded);

// If |offloaded| is nullptr, this compiles the AST in |kernel|. Otherwise it
Expand Down
4 changes: 2 additions & 2 deletions taichi/runtime/metal/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir,
void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
Kernel *kernel) {
auto compiled = run_codegen(compiled_runtime_module_, compiled_snode_trees_,
kernel, &strtab_, /*offloaded=*/nullptr);
kernel, /*offloaded=*/nullptr);
compiled.kernel_name = identifier;
ti_aot_data_.kernels.push_back(std::move(compiled));
}
Expand Down Expand Up @@ -89,7 +89,7 @@ void AotModuleBuilderImpl::add_per_backend_tmpl(const std::string &identifier,
const std::string &key,
Kernel *kernel) {
auto compiled = run_codegen(compiled_runtime_module_, compiled_snode_trees_,
kernel, &strtab_, /*offloaded=*/nullptr);
kernel, /*offloaded=*/nullptr);
for (auto &k : ti_aot_data_.tmpl_kernels) {
if (k.kernel_bundle_name == identifier) {
k.kernel_tmpl_map.insert(std::make_pair(key, compiled));
Expand Down
1 change: 0 additions & 1 deletion taichi/runtime/metal/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
const CompiledRuntimeModule *compiled_runtime_module_;
const std::vector<CompiledStructs> &compiled_snode_trees_;
const std::unordered_set<const SNode *> fields_;
PrintStringTable strtab_;
TaichiAotData ti_aot_data_;
};

Expand Down
4 changes: 1 addition & 3 deletions taichi/runtime/metal/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ class AotModuleImpl : public aot::Module {
return nullptr;
}
auto *kernel_data = itr->second;
runtime_->register_taichi_kernel(
name, kernel_data->source_code, kernel_data->kernel_attribs,
kernel_data->ctx_attribs, /*kernel=*/nullptr);
runtime_->register_taichi_kernel(*kernel_data);
return std::make_unique<KernelImpl>(runtime_, name);
}

Expand Down
81 changes: 29 additions & 52 deletions taichi/runtime/metal/kernel_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,20 +286,20 @@ class CompiledTaichiKernel {
public:
struct Params {
std::string mtl_source_code;
const TaichiKernelAttributes *ti_kernel_attribs;
const KernelContextAttributes *ctx_attribs;
MTLDevice *device;
MemoryPool *mem_pool;
KernelProfilerBase *profiler;
const CompileConfig *compile_config;
const Kernel *kernel;
Device *rhi_device;
const TaichiKernelAttributes *ti_kernel_attribs{nullptr};
const KernelContextAttributes *ctx_attribs{nullptr};
const PrintStringTable *print_str_table{nullptr};
MTLDevice *device{nullptr};
MemoryPool *mem_pool{nullptr};
KernelProfilerBase *profiler{nullptr};
const CompileConfig *compile_config{nullptr};
Device *rhi_device{nullptr};
};

CompiledTaichiKernel(Params params)
: ti_kernel_attribs(*params.ti_kernel_attribs),
ctx_attribs(*params.ctx_attribs),
kernel_(params.kernel),
print_str_table(*params.print_str_table),
rhi_device_(params.rhi_device) {
auto *const device = params.device;
auto kernel_lib = new_library_with_source(
Expand Down Expand Up @@ -420,6 +420,7 @@ class CompiledTaichiKernel {
std::vector<std::unique_ptr<CompiledMtlKernelBase>> compiled_mtl_kernels;
TaichiKernelAttributes ti_kernel_attribs;
KernelContextAttributes ctx_attribs;
PrintStringTable print_str_table;
std::unique_ptr<BufferMemoryView> ctx_mem;
nsobj_unique_ptr<MTLBuffer> ctx_buffer;

Expand All @@ -430,7 +431,6 @@ class CompiledTaichiKernel {
std::unordered_map<int, AllocAndSize> ext_arr_arg_to_dev_alloc;

private:
const Kernel *const kernel_;
Device *const rhi_device_;
};

Expand Down Expand Up @@ -726,34 +726,30 @@ class KernelManager::Impl {
root_buffers_.push_back(std::move(rtbuf));
}

void register_taichi_kernel(const std::string &taichi_kernel_name,
const std::string &mtl_kernel_source_code,
const TaichiKernelAttributes &ti_kernel_attribs,
const KernelContextAttributes &ctx_attribs,
const Kernel *kernel) {
TI_ASSERT(compiled_taichi_kernels_.find(taichi_kernel_name) ==
void register_taichi_kernel(const CompiledKernelData &compiled_kernel) {
TI_ASSERT(compiled_taichi_kernels_.find(compiled_kernel.kernel_name) ==
compiled_taichi_kernels_.end());

if (config_->print_kernel_llvm_ir) {
// If users have enabled |print_kernel_llvm_ir|, it probably means that
// they want to see the compiled code on the given arch. Maybe rename this
// flag, or add another flag (e.g. |print_kernel_source_code|)?
TI_INFO("Metal source code for kernel <{}>\n{}", taichi_kernel_name,
mtl_kernel_source_code);
TI_INFO("Metal source code for kernel <{}>\n{}",
compiled_kernel.kernel_name, compiled_kernel.source_code);
}
CompiledTaichiKernel::Params params;
params.mtl_source_code = mtl_kernel_source_code;
params.ti_kernel_attribs = &ti_kernel_attribs;
params.ctx_attribs = &ctx_attribs;
params.mtl_source_code = compiled_kernel.source_code;
params.ti_kernel_attribs = &compiled_kernel.kernel_attribs;
params.ctx_attribs = &compiled_kernel.ctx_attribs;
params.print_str_table = &compiled_kernel.print_str_table;
params.device = device_.get();
params.mem_pool = mem_pool_;
params.profiler = profiler_;
params.compile_config = config_;
params.kernel = kernel;
params.rhi_device = rhi_device_.get();
compiled_taichi_kernels_[taichi_kernel_name] =
compiled_taichi_kernels_[compiled_kernel.kernel_name] =
std::make_unique<CompiledTaichiKernel>(params);
TI_DEBUG("Registered Taichi kernel <{}>", taichi_kernel_name);
TI_DEBUG("Registered Taichi kernel <{}>", compiled_kernel.kernel_name);
}

void launch_taichi_kernel(const std::string &taichi_kernel_name,
Expand Down Expand Up @@ -812,10 +808,10 @@ class KernelManager::Impl {
ctx_blitter->metal_to_host();
}
if (used.assertion) {
check_assertion_failure();
check_assertion_failure(cti_kernel.print_str_table);
}
if (used.print) {
flush_print_buffers();
flush_print_buffers(cti_kernel.print_str_table);
}
}
}
Expand All @@ -829,10 +825,6 @@ class KernelManager::Impl {
return buffer_meta_data_;
}

PrintStringTable *print_strtable() {
return &print_strtable_;
}

std::size_t get_snode_num_dynamically_allocated(SNode *snode) {
// TODO(k-ye): Have a generic way for querying these sparse runtime stats.
mac::ScopedAutoreleasePool pool;
Expand Down Expand Up @@ -1100,7 +1092,7 @@ class KernelManager::Impl {
// print_runtime_debug();
}

void check_assertion_failure() {
void check_assertion_failure(const PrintStringTable &print_str_table) {
// TODO: Copy this to program's result_buffer, and let the Taichi runtime
// handle the assertion failures uniformly.
auto *asst_rec = reinterpret_cast<shaders::AssertRecorderData *>(
Expand All @@ -1112,7 +1104,7 @@ class KernelManager::Impl {
shaders::PrintMsg msg(msg_ptr, asst_rec->num_args);
using MsgType = shaders::PrintMsg::Type;
TI_ASSERT(msg.pm_get_type(0) == MsgType::Str);
const auto fmt_str = print_strtable_.get(msg.pm_get_data(0));
const auto fmt_str = print_str_table.get(msg.pm_get_data(0));
const auto err_str = format_error_message(fmt_str, [&msg](int argument_id) {
// +1 to skip the first arg, which is the error message template.
const int32 x = msg.pm_get_data(argument_id + 1);
Expand Down Expand Up @@ -1141,7 +1133,7 @@ class KernelManager::Impl {
throw TaichiAssertionError(err_str);
}

void flush_print_buffers() {
void flush_print_buffers(const PrintStringTable &print_str_table) {
auto *pa = reinterpret_cast<shaders::PrintMsgAllocator *>(
print_assert_idevalloc_.mem->ptr() + shaders::kMetalAssertBufferSize);
const int used_sz =
Expand All @@ -1166,7 +1158,7 @@ class KernelManager::Impl {
} else if (dt == MsgType::F32) {
py_cout << *reinterpret_cast<const float *>(&x);
} else if (dt == MsgType::Str) {
py_cout << print_strtable_.get(x);
py_cout << print_str_table.get(x);
} else {
TI_ERROR("Unexpected data type={}", dt);
}
Expand Down Expand Up @@ -1274,7 +1266,6 @@ class KernelManager::Impl {
int last_snode_id_used_in_runtime_{-1};
std::unordered_map<std::string, std::unique_ptr<CompiledTaichiKernel>>
compiled_taichi_kernels_;
PrintStringTable print_strtable_;

// The |dev_*_mirror_|s are the data structures stored in the Metal device
// side that get mirrored to the host side. This is possible because the
Expand All @@ -1301,11 +1292,7 @@ class KernelManager::Impl {
TI_ERROR("Metal not supported on the current OS");
}

void register_taichi_kernel(const std::string &taichi_kernel_name,
const std::string &mtl_kernel_source_code,
const TaichiKernelAttributes &ti_kernel_attribs,
const KernelContextAttributes &ctx_attribs,
const Kernel *kernel) {
void register_taichi_kernel(const CompiledKernelData &) {
TI_ERROR("Metal not supported on the current OS");
}

Expand Down Expand Up @@ -1351,14 +1338,8 @@ void KernelManager::add_compiled_snode_tree(const CompiledStructs &snode_tree) {
impl_->add_compiled_snode_tree(snode_tree);
}

void KernelManager::register_taichi_kernel(
const std::string &taichi_kernel_name,
const std::string &mtl_kernel_source_code,
const TaichiKernelAttributes &ti_kernel_attribs,
const KernelContextAttributes &ctx_attribs,
const Kernel *kernel) {
impl_->register_taichi_kernel(taichi_kernel_name, mtl_kernel_source_code,
ti_kernel_attribs, ctx_attribs, kernel);
void KernelManager::register_taichi_kernel(const CompiledKernelData &compiled) {
impl_->register_taichi_kernel(compiled);
}

void KernelManager::launch_taichi_kernel(const std::string &taichi_kernel_name,
Expand All @@ -1374,10 +1355,6 @@ BufferMetaData KernelManager::get_buffer_meta_data() {
return impl_->get_buffer_meta_data();
}

PrintStringTable *KernelManager::print_strtable() {
return impl_->print_strtable();
}

std::size_t KernelManager::get_snode_num_dynamically_allocated(SNode *snode) {
return impl_->get_snode_num_dynamically_allocated(snode);
}
Expand Down
6 changes: 1 addition & 5 deletions taichi/runtime/metal/kernel_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,7 @@ class KernelManager {
// TODO(k-ye): Remove |taichi_kernel_name| now that it's part of
// |ti_kernel_attribs|. Return a handle that will be passed to
// launch_taichi_kernel(), instead of using kernel name as the identifier.
void register_taichi_kernel(const std::string &taichi_kernel_name,
const std::string &mtl_kernel_source_code,
const TaichiKernelAttributes &ti_kernel_attribs,
const KernelContextAttributes &ctx_attribs,
const Kernel *kernel);
void register_taichi_kernel(const CompiledKernelData &compiled_kernel);

// Launch the given |taichi_kernel_name|.
// Kernel launching is asynchronous, therefore the Metal memory is not valid
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/metal/kernel_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ int PrintStringTable::put(const std::string &str) {
return i;
}

const std::string &PrintStringTable::get(int i) {
const std::string &PrintStringTable::get(int i) const {
return strs_[i];
}

Expand Down
7 changes: 5 additions & 2 deletions taichi/runtime/metal/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ namespace metal {
class PrintStringTable {
public:
int put(const std::string &str);
const std::string &get(int i);
const std::string &get(int i) const;

TI_IO_DEF(strs_);

private:
std::vector<std::string> strs_;
Expand Down Expand Up @@ -279,8 +281,9 @@ struct CompiledKernelData {
std::string source_code;
KernelContextAttributes ctx_attribs;
TaichiKernelAttributes kernel_attribs;
PrintStringTable print_str_table;

TI_IO_DEF(kernel_name, ctx_attribs, kernel_attribs);
TI_IO_DEF(kernel_name, ctx_attribs, kernel_attribs, print_str_table);
};

struct CompiledKernelTmplData {
Expand Down

0 comments on commit c4174e0

Please sign in to comment.