Skip to content

Commit

Permalink
[refactor] Remove dependencies on Callable::program in lang::get_hash…
Browse files Browse the repository at this point in the history
…ed_offline_cache_key (taichi-dev#7287)

Issue: taichi-dev#7286
* taichi-dev#7286: Part of taichi-dev#7002

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent b9e15a0 commit 400a123
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 18 deletions.
21 changes: 9 additions & 12 deletions taichi/analysis/gen_offline_cache_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
using IRVisitor::visit;

public:
ASTSerializer(Program *prog, std::ostream *os)
: ExpressionVisitor(true), prog_(prog), os_(os) {
explicit ASTSerializer(std::ostream *os) : ExpressionVisitor(true), os_(os) {
// TODO(PGZXB): Set allow_undefined_visitor as false. (blocked by
// constant-folding)
this->allow_undefined_visitor = true;
Expand Down Expand Up @@ -414,8 +413,8 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
emit(stmt->outputs);
}

static void run(Program *prog, IRNode *ast, std::ostream *os) {
ASTSerializer serializer(prog, os);
static void run(IRNode *ast, std::ostream *os) {
ASTSerializer serializer(os);
ast->accept(&serializer);
serializer.emit_dependencies();
}
Expand All @@ -434,7 +433,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
// Serialize snode_trees(Temporary: using offline-cache-key of SNode)
// Note: The result of serializing snode_tree_roots_ is not parsable now
emit(static_cast<std::size_t>(snode_tree_roots_.size()));
for (auto *snode : snode_tree_roots_) {
for (const auto *snode : snode_tree_roots_) {
auto key = get_hashed_offline_cache_key_of_snode(snode);
emit_bytes(key.c_str(), key.size());
}
Expand Down Expand Up @@ -517,12 +516,11 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
}
}

void emit(SNode *snode) {
TI_ASSERT(prog_);
void emit(const SNode *snode) {
if (snode) {
emit(static_cast<std::size_t>(snode->get_snode_tree_id()));
emit(static_cast<std::size_t>(snode->id));
auto *root = prog_->get_snode_root(snode->get_snode_tree_id());
const auto *root = snode->get_root();
snode_tree_roots_.insert(root);
} else {
emit(std::numeric_limits<std::size_t>::max());
Expand Down Expand Up @@ -643,17 +641,16 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {

#undef DEFINE_EMIT_ENUM

Program *prog_{nullptr};
std::ostream *os_{nullptr};
std::unordered_set<SNode *> snode_tree_roots_;
std::unordered_set<const SNode *> snode_tree_roots_;
std::unordered_map<Function *, std::size_t> real_funcs_;
std::vector<char> string_pool_;
};

} // namespace

void gen_offline_cache_key(Program *prog, IRNode *ast, std::ostream *os) {
ASTSerializer::run(prog, ast, os);
void gen_offline_cache_key(IRNode *ast, std::ostream *os) {
ASTSerializer::run(ast, os);
}

} // namespace taichi::lang
6 changes: 3 additions & 3 deletions taichi/analysis/offline_cache_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ static std::vector<std::uint8_t> get_offline_cache_key_of_compile_config(
}

static void get_offline_cache_key_of_snode_impl(
SNode *snode,
const SNode *snode,
BinaryOutputSerializer &serializer,
std::unordered_set<int> &visited) {
if (auto iter = visited.find(snode->id); iter != visited.end()) {
Expand Down Expand Up @@ -122,7 +122,7 @@ static void get_offline_cache_key_of_snode_impl(
serializer(snode->get_snode_tree_id());
}

std::string get_hashed_offline_cache_key_of_snode(SNode *snode) {
std::string get_hashed_offline_cache_key_of_snode(const SNode *snode) {
TI_ASSERT(snode);

BinaryOutputSerializer serializer;
Expand All @@ -145,7 +145,7 @@ std::string get_hashed_offline_cache_key(const CompileConfig &config,
std::string kernel_ast_string;
if (kernel) {
std::ostringstream oss;
gen_offline_cache_key(kernel->program, kernel->ir.get(), &oss);
gen_offline_cache_key(kernel->ir.get(), &oss);
kernel_ast_string = oss.str();
}

Expand Down
4 changes: 2 additions & 2 deletions taichi/analysis/offline_cache_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ class IRNode;
class SNode;
class Kernel;

std::string get_hashed_offline_cache_key_of_snode(SNode *snode);
std::string get_hashed_offline_cache_key_of_snode(const SNode *snode);
std::string get_hashed_offline_cache_key(const CompileConfig &config,
Kernel *kernel);
void gen_offline_cache_key(Program *prog, IRNode *ast, std::ostream *os);
void gen_offline_cache_key(IRNode *ast, std::ostream *os);

} // namespace taichi::lang
7 changes: 7 additions & 0 deletions taichi/ir/snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,4 +330,11 @@ int SNode::get_snode_tree_id() const {
return snode_tree_id_;
}

const SNode *SNode::get_root() const {
if (!parent) { // root->parent == nullptr
return this;
}
return parent->get_root();
}

} // namespace taichi::lang
2 changes: 2 additions & 0 deletions taichi/ir/snode.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ class SNode {

int get_snode_tree_id() const;

const SNode *get_root() const;

static void reset_counter() {
counter = 0;
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/program/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ void Function::set_function_body(const std::function<void()> &func) {

if (program->compile_config().offline_cache) { // For generating AST-Key
std::ostringstream oss;
gen_offline_cache_key(program, ir.get(), &oss);
gen_offline_cache_key(ir.get(), &oss);
ast_serialization_data_ = oss.str();
}
}
Expand Down

0 comments on commit 400a123

Please sign in to comment.