Skip to content

Commit

Permalink
[ir] Fix compilation crash when offloading LocalStoreStmt and AtomicO…
Browse files Browse the repository at this point in the history
…pStmt
  • Loading branch information
xumingkuan committed Apr 18, 2020
1 parent f76b098 commit 9ec538a
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 27 deletions.
9 changes: 9 additions & 0 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,15 @@ bool Stmt::have_operand(Stmt *stmt) const {
return false;
}

int Stmt::locate_operand(Stmt **stmt) {
for (int i = 0; i < num_operands(); i++) {
if (operands[i] == stmt) {
return i;
}
}
return -1;
}

std::string Expression::get_attribute(const std::string &key) const {
if (auto it = attributes.find(key); it == attributes.end()) {
TI_ERROR("Attribute {} not found.", key);
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ class Stmt : public IRNode {

void set_operand(int i, Stmt *stmt);
void register_operand(Stmt *&stmt);
int locate_operand(Stmt **stmt);
void mark_fields_registered();

virtual void rebuild_operands() {
Expand Down
66 changes: 39 additions & 27 deletions taichi/transforms/offload.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <set>
#include <unordered_map>
#include <utility>

#include "taichi/ir/ir.h"

Expand Down Expand Up @@ -345,7 +346,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
const StmtToOffsetMap &local_to_global_offset,
std::unordered_map<Stmt *, Stmt *> stmt_to_offloaded)
: local_to_global_offset(local_to_global_offset),
stmt_to_offloaded(stmt_to_offloaded) {
stmt_to_offloaded(std::move(stmt_to_offloaded)) {
allow_undefined_visitor = true;
invoke_default_visitor = true;
}
Expand Down Expand Up @@ -399,6 +400,8 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
}

void visit(LocalStoreStmt *stmt) override {
if (visit_operand(stmt, stmt->locate_operand(&stmt->data)))
throw IRModified();
TI_ASSERT(stmt->width() == 1);
auto alloca = stmt->ptr;
if (local_to_global_offset.find(alloca) == local_to_global_offset.end())
Expand All @@ -416,6 +419,8 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
}

void visit(AtomicOpStmt *stmt) override {
if (visit_operand(stmt, stmt->locate_operand(&stmt->val)))
throw IRModified();
TI_ASSERT(stmt->width() == 1);
auto alloca = stmt->dest;
if (local_to_global_offset.find(alloca) == local_to_global_offset.end())
Expand All @@ -432,38 +437,45 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
throw IRModified();
}

bool visit_operand(Stmt *stmt, int index) {
// return true if modified
TI_ASSERT(index >= 0 && index < stmt->num_operands());
auto op = stmt->operand(index);
if (op == nullptr)
return false;
if (stmt_to_offloaded[stmt] ==
stmt_to_offloaded[op]) // same OffloadedStmt
return false;
if (advanced_optimization) {
if (op->is<ConstStmt>()) {
auto copy = op->as<ConstStmt>()->copy();
stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt];
stmt->set_operand(index, copy.get());
stmt->insert_before_me(std::move(copy));
return true;
}
}
if (local_to_global_offset.find(op) == local_to_global_offset.end())
return false;

auto global = Stmt::make<GlobalTemporaryStmt>(local_to_global_offset[op],
op->ret_type);
auto load = Stmt::make<GlobalLoadStmt>(global.get());
stmt_to_offloaded[load.get()] = stmt_to_offloaded[stmt];
stmt->set_operand(index, load.get());
stmt->insert_before_me(std::move(global));
stmt->insert_before_me(std::move(load));
return true;
}

// Generic visitor
void visit(Stmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
int n_op = stmt->num_operands();
bool modified = false;
for (int i = 0; i < n_op; i++) {
auto op = stmt->operand(i);
if (op == nullptr)
continue;
if (stmt_to_offloaded[stmt] ==
stmt_to_offloaded[op]) // same OffloadedStmt
continue;
if (advanced_optimization) {
if (op->is<ConstStmt>()) {
auto copy = op->as<ConstStmt>()->copy();
stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt];
stmt->set_operand(i, copy.get());
stmt->insert_before_me(std::move(copy));
modified = true;
continue;
}
}
if (local_to_global_offset.find(op) == local_to_global_offset.end())
continue;

auto global = Stmt::make<GlobalTemporaryStmt>(local_to_global_offset[op],
op->ret_type);
auto load = Stmt::make<GlobalLoadStmt>(global.get());
stmt->set_operand(i, load.get());
stmt->insert_before_me(std::move(global));
stmt->insert_before_me(std::move(load));
modified = true;
if (visit_operand(stmt, i))
modified = true;
}
if (modified)
throw IRModified();
Expand Down

0 comments on commit 9ec538a

Please sign in to comment.