Skip to content

Commit

Permalink
Refactor2023: Split lang::Function
Browse files Browse the repository at this point in the history
  • Loading branch information
PGZXB committed Dec 12, 2022
1 parent 2760e61 commit d87e6d9
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 21 deletions.
12 changes: 9 additions & 3 deletions taichi/analysis/gen_offline_cache_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,10 +425,16 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
// Serialize dependent real-functions
emit(real_funcs_.size());
for (auto &[func, id] : real_funcs_) {
if (auto &ast_str = func->try_get_ast_serialization_data();
ast_str.has_value()) {
emit_bytes(ast_str->c_str(), ast_str->size());
if (const auto &ast_str = func->try_get_ast_serialization_data();
!ast_str.has_value()) {
func->set_ast_serialization_data("\xff");
std::ostringstream oss;
gen_offline_cache_key(prog_, func->ir.get(), &oss);
func->set_ast_serialization_data(oss.str());
}
const auto &ast_str = func->try_get_ast_serialization_data();
TI_ASSERT(ast_str.has_value() && !ast_str->empty());
emit_bytes(ast_str->c_str(), ast_str->size());
}

// Serialize snode_trees(Temporary: using offline-cache-key of SNode)
Expand Down
2 changes: 2 additions & 0 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ void compile_function(IRNode *ir,
void ast_to_ir(const CompileConfig &config,
/*FIXME:Fix to const */ Kernel &kernel,
bool to_executable = true);

void lower_called_functions(const CompileConfig &compile_config, IRNode *ir);
} // namespace irpass

} // namespace taichi::lang
19 changes: 2 additions & 17 deletions taichi/program/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,14 @@ Function::Function(Program *program, const FunctionKey &func_key)
void Function::set_function_body(const std::function<void()> &func) {
context = std::make_unique<FrontendContext>();
ir = context->get_root();
ir_start_from_ast_ = true;

func();

if (program->global_compile_config()
.offline_cache) { // For generating AST-Key
std::ostringstream oss;
gen_offline_cache_key(program, ir.get(), &oss);
ast_serialization_data_ = oss.str();
}
irpass::compile_function(
ir.get(), program->global_compile_config(), this,
/*autodiff_mode=*/AutodiffMode::kNone,
/*verbose=*/program->global_compile_config().print_ir,
/*start_from_ast=*/true);
}

void Function::set_function_body(std::unique_ptr<IRNode> func_body) {
ir = std::move(func_body);
irpass::compile_function(
ir.get(), program->global_compile_config(), this,
/*autodiff_mode=*/AutodiffMode::kNone,
/*verbose=*/program->global_compile_config().print_ir,
/*start_from_ast=*/false);
ir_start_from_ast_ = false;
}

std::string Function::get_name() const {
Expand Down
20 changes: 19 additions & 1 deletion taichi/program/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,29 @@ class Function : public Callable {

[[nodiscard]] std::string get_name() const override;

std::optional<std::string> &try_get_ast_serialization_data() {
const std::optional<std::string> &try_get_ast_serialization_data() const {
return ast_serialization_data_;
}

void set_ast_serialization_data(std::string ast_data) {
ast_serialization_data_ = std::move(ast_data);
}

bool lowered() const {
return lowered_;
}

void set_lowered(bool lowered) {
lowered_ = lowered;
}

bool ir_start_from_ast() const {
return ir_start_from_ast_;
}

private:
bool ir_start_from_ast_{false}; // Refactor2023:FIXME: Remove it
bool lowered_{false};
std::optional<std::string> ast_serialization_data_; // For generating AST-Key
};

Expand Down
3 changes: 3 additions & 0 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ void compile_to_offloads(IRNode *ir,
print("Lowered");
}


irpass::lower_called_functions(config, ir);

if (config.real_matrix && config.real_matrix_scalarize) {
irpass::scalarize(ir, config.dynamic_index);

Expand Down
47 changes: 47 additions & 0 deletions taichi/transforms/lower_called_functions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/ir/statements.h"
#include "taichi/program/function.h"
#include "taichi/program/compile_config.h"

namespace taichi::lang {

class LowerCalledFunctions : public BasicStmtVisitor {
public:
using BasicStmtVisitor::visit;

explicit LowerCalledFunctions(const CompileConfig &compile_config)
: compile_config_(compile_config) {
}

void visit(FuncCallStmt *stmt) override {
auto *func = stmt->func;
if (!func->lowered()) {
func->set_lowered(true);
irpass::compile_function(func->ir.get(), compile_config_, func,
/*autodiff_mode=*/AutodiffMode::kNone,
/*verbose=*/compile_config_.print_ir,
/*start_from_ast=*/func->ir_start_from_ast());
func->ir->accept(this);
}
}

static void run(const CompileConfig &compile_config, IRNode *root) {
LowerCalledFunctions lcf{compile_config};
root->accept(&lcf);
}

private:
const CompileConfig &compile_config_;
};

namespace irpass {

void lower_called_functions(const CompileConfig &compile_config, IRNode *root) {
TI_AUTO_PROF;
LowerCalledFunctions::run(compile_config, root);
}

} // namespace irpass

} // namespace taichi::lang

0 comments on commit d87e6d9

Please sign in to comment.