Skip to content

Commit

Permalink
[Bug] [ir] Fix the IdentifyValuesUsedInOtherOffloads pass (#3597)
Browse files Browse the repository at this point in the history
  • Loading branch information
strongoier committed Nov 23, 2021
1 parent a13c37d commit 26b833e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
17 changes: 9 additions & 8 deletions taichi/transforms/offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,6 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor {
TI_ASSERT(current_offloaded);
}

void visit(RangeForStmt *stmt) override {
test_and_allocate(stmt->begin);
test_and_allocate(stmt->end);
if (stmt->body)
stmt->body->accept(this);
}

void test_and_allocate(Stmt *stmt) {
if (stmt == nullptr)
return;
Expand All @@ -371,14 +364,22 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor {
}
}

void visit(Stmt *stmt) override {
void generic_visit(Stmt *stmt) {
int n_op = stmt->num_operands();
for (int i = 0; i < n_op; i++) {
auto op = stmt->operand(i);
test_and_allocate(op);
}
}

void preprocess_container_stmt(Stmt *stmt) override {
generic_visit(stmt);
}

void visit(Stmt *stmt) override {
generic_visit(stmt);
}

static StmtToOffsetMap run(
IRNode *root,
const CompileConfig &config,
Expand Down
12 changes: 12 additions & 0 deletions tests/python/test_offload_cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,15 @@ def run(a: ti.i32):
print('OK')

run(2)


@ti.test()
def test_offload_with_cross_if_inside_for():
@ti.kernel
def run(a: ti.i32):
b = a > 2
for x in range(1):
if b:
print('OK')

run(2)

0 comments on commit 26b833e

Please sign in to comment.