Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ir] Fix compilation crash when offloading LocalStoreStmt and AtomicOpStmt #813

Merged
merged 1 commit into from
Apr 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,15 @@ bool Stmt::have_operand(Stmt *stmt) const {
return false;
}

int Stmt::locate_operand(Stmt **stmt) {
for (int i = 0; i < num_operands(); i++) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here should we consider something like $1 = add $0, $0?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And we should probably add the tests in.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used Stmt ** rather than Stmt *, so I think the values like the same $0s wouldn't affect.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And we should probably add the tests in.

Yeah, we should add #812 (comment) in. But I haven't come up with the name of the test.

(I don't think locate_operand need tests)

if (operands[i] == stmt) {
return i;
}
}
return -1;
}

std::string Expression::get_attribute(const std::string &key) const {
if (auto it = attributes.find(key); it == attributes.end()) {
TI_ERROR("Attribute {} not found.", key);
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ class Stmt : public IRNode {

void set_operand(int i, Stmt *stmt);
void register_operand(Stmt *&stmt);
int locate_operand(Stmt **stmt);
void mark_fields_registered();

virtual void rebuild_operands() {
Expand Down
66 changes: 39 additions & 27 deletions taichi/transforms/offload.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <set>
#include <unordered_map>
#include <utility>

#include "taichi/ir/ir.h"

Expand Down Expand Up @@ -345,7 +346,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
const StmtToOffsetMap &local_to_global_offset,
std::unordered_map<Stmt *, Stmt *> stmt_to_offloaded)
: local_to_global_offset(local_to_global_offset),
stmt_to_offloaded(stmt_to_offloaded) {
stmt_to_offloaded(std::move(stmt_to_offloaded)) {
allow_undefined_visitor = true;
invoke_default_visitor = true;
}
Expand Down Expand Up @@ -399,6 +400,8 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
}

void visit(LocalStoreStmt *stmt) override {
if (visit_operand(stmt, stmt->locate_operand(&stmt->data)))
throw IRModified();
TI_ASSERT(stmt->width() == 1);
auto alloca = stmt->ptr;
if (local_to_global_offset.find(alloca) == local_to_global_offset.end())
Expand All @@ -416,6 +419,8 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
}

void visit(AtomicOpStmt *stmt) override {
if (visit_operand(stmt, stmt->locate_operand(&stmt->val)))
throw IRModified();
TI_ASSERT(stmt->width() == 1);
auto alloca = stmt->dest;
if (local_to_global_offset.find(alloca) == local_to_global_offset.end())
Expand All @@ -432,38 +437,45 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
throw IRModified();
}

bool visit_operand(Stmt *stmt, int index) {
// return true if modified
TI_ASSERT(index >= 0 && index < stmt->num_operands());
auto op = stmt->operand(index);
if (op == nullptr)
return false;
if (stmt_to_offloaded[stmt] ==
stmt_to_offloaded[op]) // same OffloadedStmt
return false;
if (advanced_optimization) {
if (op->is<ConstStmt>()) {
auto copy = op->as<ConstStmt>()->copy();
stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt];
stmt->set_operand(index, copy.get());
stmt->insert_before_me(std::move(copy));
return true;
}
}
if (local_to_global_offset.find(op) == local_to_global_offset.end())
return false;

auto global = Stmt::make<GlobalTemporaryStmt>(local_to_global_offset[op],
op->ret_type);
auto load = Stmt::make<GlobalLoadStmt>(global.get());
stmt_to_offloaded[load.get()] = stmt_to_offloaded[stmt];
stmt->set_operand(index, load.get());
stmt->insert_before_me(std::move(global));
stmt->insert_before_me(std::move(load));
return true;
}

// Generic visitor
void visit(Stmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
int n_op = stmt->num_operands();
bool modified = false;
for (int i = 0; i < n_op; i++) {
auto op = stmt->operand(i);
if (op == nullptr)
continue;
if (stmt_to_offloaded[stmt] ==
stmt_to_offloaded[op]) // same OffloadedStmt
continue;
if (advanced_optimization) {
if (op->is<ConstStmt>()) {
auto copy = op->as<ConstStmt>()->copy();
stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt];
stmt->set_operand(i, copy.get());
stmt->insert_before_me(std::move(copy));
modified = true;
continue;
}
}
if (local_to_global_offset.find(op) == local_to_global_offset.end())
continue;

auto global = Stmt::make<GlobalTemporaryStmt>(local_to_global_offset[op],
op->ret_type);
auto load = Stmt::make<GlobalLoadStmt>(global.get());
stmt->set_operand(i, load.get());
stmt->insert_before_me(std::move(global));
stmt->insert_before_me(std::move(load));
modified = true;
if (visit_operand(stmt, i))
modified = true;
}
if (modified)
throw IRModified();
Expand Down