Skip to content

Commit

Permalink
[opt] [ir] [refactor] Remove exceptions from offload pass (#3925)
Browse files Browse the repository at this point in the history
* call recursively generic_visit

* remove redundant catch code

* remove another `catch`

* fix
  • Loading branch information
mzmzm committed Jan 3, 2022
1 parent 6056b49 commit f60fcde
Showing 1 changed file with 7 additions and 25 deletions.
32 changes: 7 additions & 25 deletions taichi/transforms/offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,20 +464,12 @@ class PromoteIntermediateToGlobalTmp : public BasicStmtVisitor {
auto ptr = stmt->insert_after_me(
Stmt::make<GlobalTemporaryStmt>(offset, stmt->ret_type));
ptr->insert_after_me(Stmt::make<GlobalStoreStmt>(ptr, stmt));
throw IRModified();
}
}

static void run(IRNode *root, const StmtToOffsetMap &local_to_global_offset) {
PromoteIntermediateToGlobalTmp pass(local_to_global_offset);
while (true) {
try {
root->accept(&pass);
} catch (IRModified) {
continue;
}
break;
}
root->accept(&pass);
}

private:
Expand Down Expand Up @@ -577,7 +569,6 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
stmt->parent->replace_with(stmt, std::move(replacement), false);
// To deal with the same offloaded visit_operand()
stmt_to_offloaded_[stmt] = nullptr;
throw IRModified();
}

// Replace local LD/ST with global LD/ST
Expand All @@ -591,7 +582,6 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
auto global_load = replacement.push_back<GlobalLoadStmt>(ptr);
stmt_to_offloaded_[global_load] = stmt_to_offloaded_[stmt];
stmt->parent->replace_with(stmt, std::move(replacement));
throw IRModified();
}
}

Expand All @@ -605,7 +595,6 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
replacement.push_back<GlobalStoreStmt>(ptr, stmt->val);
stmt_to_offloaded_[global_store] = stmt_to_offloaded_[stmt];
stmt->parent->replace_with(stmt, std::move(replacement));
throw IRModified();
}
}

Expand All @@ -623,10 +612,12 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {

if (op->is<GlobalPtrStmt>()) {
auto copy = op->clone();
auto pcopy = copy.get();
copy->as<GlobalPtrStmt>()->activate = false;
stmt_to_offloaded_[copy.get()] = offloaded;
stmt->set_operand(index, copy.get());
stmt->insert_before_me(std::move(copy));
generic_visit(pcopy);
return true;
}

Expand All @@ -638,9 +629,11 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
"{} is not allowed here.", op->type());
// For cases like ConstStmt
auto copy = op->clone();
auto pcopy = copy.get();
stmt_to_offloaded_[copy.get()] = offloaded;
stmt->set_operand(index, copy.get());
stmt->insert_before_me(std::move(copy));
generic_visit(pcopy);
} else {
auto global_temporary = Stmt::make<GlobalTemporaryStmt>(
local_to_global_offset_[op], op->ret_type);
Expand All @@ -664,13 +657,9 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {

void generic_visit(Stmt *stmt) {
int n_op = stmt->num_operands();
bool modified = false;
for (int i = 0; i < n_op; i++) {
if (visit_operand(stmt, i))
modified = true;
visit_operand(stmt, i);
}
if (modified)
throw IRModified();
}

void visit(Stmt *stmt) override {
Expand All @@ -690,14 +679,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
OffloadedRanges *offloaded_ranges) {
FixCrossOffloadReferences pass(config, local_to_global_offset,
stmt_to_offloaded, offloaded_ranges);
while (true) {
try {
root->accept(&pass);
} catch (IRModified) {
continue;
}
break;
}
root->accept(&pass);
}

private:
Expand Down

0 comments on commit f60fcde

Please sign in to comment.