Skip to content

Commit

Permalink
[opt] Avoid storing constants across offloaded tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
xumingkuan authored and yuanming-hu committed Apr 17, 2020
1 parent b79d93d commit 7d3bd6c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
4 changes: 4 additions & 0 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,10 @@ void SNodeOpExpression::flatten(VecStatement &ret) {
stmt = ret.back().get();
}

std::unique_ptr<ConstStmt> ConstStmt::copy() {
return std::make_unique<ConstStmt>(val);
}

For::For(const Expr &s, const Expr &e, const std::function<void(Expr)> &func) {
auto i = Expr(std::make_shared<IdExpression>());
auto stmt_unique = std::make_unique<FrontendForStmt>(i, s, e);
Expand Down
2 changes: 2 additions & 0 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1622,6 +1622,8 @@ class ConstStmt : public Stmt {
return false;
}

std::unique_ptr<ConstStmt> copy();

TI_STMT_DEF_FIELDS(ret_type, val);
DEFINE_ACCEPT
};
Expand Down
16 changes: 14 additions & 2 deletions taichi/transforms/offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,10 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor {
return;
if (stmt_to_offloaded[stmt] == current_offloaded)
return;
if (stmt->is<ConstStmt>()) {
// Directly insert copies of ConstStmts later
return;
}
if (local_to_global.find(stmt) == local_to_global.end()) {
// Not yet allocated
local_to_global[stmt] = allocate_global(stmt->ret_type);
Expand Down Expand Up @@ -435,11 +439,19 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
auto op = stmt->operand(i);
if (op == nullptr)
continue;
if (local_to_global_offset.find(op) == local_to_global_offset.end())
continue;
if (stmt_to_offloaded[stmt] ==
stmt_to_offloaded[op]) // same OffloadedStmt
continue;
if (op->is<ConstStmt>()) {
auto copy = op->as<ConstStmt>()->copy();
stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt];
stmt->set_operand(i, copy.get());
stmt->insert_before_me(std::move(copy));
modified = true;
continue;
}
if (local_to_global_offset.find(op) == local_to_global_offset.end())
continue;

auto global = Stmt::make<GlobalTemporaryStmt>(local_to_global_offset[op],
op->ret_type);
Expand Down

0 comments on commit 7d3bd6c

Please sign in to comment.