From d87e6d903cdbe10c28311486e877f134b1ad56fc Mon Sep 17 00:00:00 2001 From: PGZXB Date: Tue, 13 Dec 2022 00:01:05 +0800 Subject: [PATCH] Refactor2023: Split lang::Function --- taichi/analysis/gen_offline_cache_key.cpp | 12 +++-- taichi/ir/transforms.h | 2 + taichi/program/function.cpp | 19 +------- taichi/program/function.h | 20 ++++++++- taichi/transforms/compile_to_offloads.cpp | 3 ++ taichi/transforms/lower_called_functions.cpp | 47 ++++++++++++++++++++ 6 files changed, 82 insertions(+), 21 deletions(-) create mode 100644 taichi/transforms/lower_called_functions.cpp diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index a133d474adee7..ace282c0363fd 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -425,10 +425,16 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { // Serialize dependent real-functions emit(real_funcs_.size()); for (auto &[func, id] : real_funcs_) { - if (auto &ast_str = func->try_get_ast_serialization_data(); - ast_str.has_value()) { - emit_bytes(ast_str->c_str(), ast_str->size()); + if (const auto &ast_str = func->try_get_ast_serialization_data(); + !ast_str.has_value()) { + func->set_ast_serialization_data("\xff"); + std::ostringstream oss; + gen_offline_cache_key(prog_, func->ir.get(), &oss); + func->set_ast_serialization_data(oss.str()); } + const auto &ast_str = func->try_get_ast_serialization_data(); + TI_ASSERT(ast_str.has_value() && !ast_str->empty()); + emit_bytes(ast_str->c_str(), ast_str->size()); } // Serialize snode_trees(Temporary: using offline-cache-key of SNode) diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index 69cc38c2237bb..bc18b703ea848 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -198,6 +198,8 @@ void compile_function(IRNode *ir, void ast_to_ir(const CompileConfig &config, /*FIXME:Fix to const */ Kernel &kernel, bool to_executable = true); + +void lower_called_functions(const CompileConfig &compile_config, IRNode *ir); } // namespace irpass } // namespace taichi::lang diff --git a/taichi/program/function.cpp b/taichi/program/function.cpp index 0085a760bd67b..e220f585aced1 100644 --- a/taichi/program/function.cpp +++ b/taichi/program/function.cpp @@ -13,29 +13,14 @@ Function::Function(Program *program, const FunctionKey &func_key) void Function::set_function_body(const std::function &func) { context = std::make_unique(); ir = context->get_root(); + ir_start_from_ast_ = true; func(); - - if (program->global_compile_config() - .offline_cache) { // For generating AST-Key - std::ostringstream oss; - gen_offline_cache_key(program, ir.get(), &oss); - ast_serialization_data_ = oss.str(); - } - irpass::compile_function( - ir.get(), program->global_compile_config(), this, - /*autodiff_mode=*/AutodiffMode::kNone, - /*verbose=*/program->global_compile_config().print_ir, - /*start_from_ast=*/true); } void Function::set_function_body(std::unique_ptr func_body) { ir = std::move(func_body); - irpass::compile_function( - ir.get(), program->global_compile_config(), this, - /*autodiff_mode=*/AutodiffMode::kNone, - /*verbose=*/program->global_compile_config().print_ir, - /*start_from_ast=*/false); + ir_start_from_ast_ = false; } std::string Function::get_name() const { diff --git a/taichi/program/function.h b/taichi/program/function.h index 2da23219d6c0c..85d8e88e1f9e9 100644 --- a/taichi/program/function.h +++ b/taichi/program/function.h @@ -22,11 +22,29 @@ class Function : public Callable { [[nodiscard]] std::string get_name() const override; - std::optional &try_get_ast_serialization_data() { + const std::optional &try_get_ast_serialization_data() const { return ast_serialization_data_; } + void set_ast_serialization_data(std::string ast_data) { + ast_serialization_data_ = std::move(ast_data); + } + + bool lowered() const { + return lowered_; + } + + void set_lowered(bool lowered) { + lowered_ = lowered; + } + + bool ir_start_from_ast() const { + return ir_start_from_ast_; + } + private: + bool ir_start_from_ast_{false}; // Refactor2023:FIXME: Remove it + bool lowered_{false}; std::optional ast_serialization_data_; // For generating AST-Key }; diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index dd7856f291e99..c004b15d7d502 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -61,6 +61,9 @@ void compile_to_offloads(IRNode *ir, print("Lowered"); } + + irpass::lower_called_functions(config, ir); + if (config.real_matrix && config.real_matrix_scalarize) { irpass::scalarize(ir, config.dynamic_index); diff --git a/taichi/transforms/lower_called_functions.cpp b/taichi/transforms/lower_called_functions.cpp new file mode 100644 index 0000000000000..6040e7563a034 --- /dev/null +++ b/taichi/transforms/lower_called_functions.cpp @@ -0,0 +1,47 @@ +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" +#include "taichi/ir/statements.h" +#include "taichi/program/function.h" +#include "taichi/program/compile_config.h" + +namespace taichi::lang { + +class LowerCalledFunctions : public BasicStmtVisitor { + public: + using BasicStmtVisitor::visit; + + explicit LowerCalledFunctions(const CompileConfig &compile_config) + : compile_config_(compile_config) { + } + + void visit(FuncCallStmt *stmt) override { + auto *func = stmt->func; + if (!func->lowered()) { + func->set_lowered(true); + irpass::compile_function(func->ir.get(), compile_config_, func, + /*autodiff_mode=*/AutodiffMode::kNone, + /*verbose=*/compile_config_.print_ir, + /*start_from_ast=*/func->ir_start_from_ast()); + func->ir->accept(this); + } + } + + static void run(const CompileConfig &compile_config, IRNode *root) { + LowerCalledFunctions lcf{compile_config}; + root->accept(&lcf); + } + + private: + const CompileConfig &compile_config_; +}; + +namespace irpass { + +void lower_called_functions(const CompileConfig &compile_config, IRNode *root) { + TI_AUTO_PROF; + LowerCalledFunctions::run(compile_config, root); +} + +} // namespace irpass + +} // namespace taichi::lang