Skip to content

Commit

Permalink
Refactor2023: Remove dependencies on Program::this_thread_config() in…
Browse files Browse the repository at this point in the history
… irpass::scalarize
  • Loading branch information
PGZXB committed Dec 7, 2022
1 parent b8deaf0 commit 40d4b33
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace irpass {

void re_id(IRNode *root);
void flag_access(IRNode *root);
void scalarize(IRNode *root);
void scalarize(IRNode *root, bool dynamic_index);
void lower_matrix_ptr(IRNode *root);
bool die(IRNode *root);
bool simplify(IRNode *root, const CompileConfig &config);
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void compile_to_offloads(IRNode *ir,
}

if (config.real_matrix && config.real_matrix_scalarize) {
irpass::scalarize(ir);
irpass::scalarize(ir, config.dynamic_index);

// Remove redundant MatrixInitStmt inserted during scalarization
irpass::die(ir);
Expand Down
5 changes: 2 additions & 3 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,11 +612,10 @@ class ScalarizePointers : public BasicStmtVisitor {
};

namespace irpass {

void scalarize(IRNode *root) {
void scalarize(IRNode *root, bool dynamic_index) {
TI_AUTO_PROF;
Scalarize scalarize_pass(root);
if (!root->get_kernel()->program->this_thread_config().dynamic_index) {
if (!dynamic_index) {
ScalarizePointers scalarize_pointers_pass(root);
}
}
Expand Down
12 changes: 8 additions & 4 deletions tests/cpp/transforms/scalarize_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ TEST(Scalarize, ScalarizeGlobalStore) {

block->push_back<GlobalStoreStmt>(dest_stmt, matrix_init_stmt);

irpass::scalarize(block.get());
irpass::scalarize(block.get(),
test_prog.prog()->this_thread_config().dynamic_index);
irpass::lower_matrix_ptr(block.get());
irpass::die(block.get());

Expand Down Expand Up @@ -102,7 +103,8 @@ TEST(Scalarize, ScalarizeGlobalLoad) {
// Without this GlobalStoreStmt, nothing survives irpass::die()
block->push_back<GlobalStoreStmt>(src_stmt, load_stmt);

irpass::scalarize(block.get());
irpass::scalarize(block.get(),
test_prog.prog()->this_thread_config().dynamic_index);
irpass::lower_matrix_ptr(block.get());
irpass::die(block.get());

Expand Down Expand Up @@ -163,7 +165,8 @@ TEST(Scalarize, ScalarizeLocalStore) {
// LocalStoreStmt survives irpass::die()
block->push_back<LocalStoreStmt>(dest_stmt, matrix_init_stmt);

irpass::scalarize(block.get());
irpass::scalarize(block.get(),
test_prog.prog()->this_thread_config().dynamic_index);
irpass::die(block.get());

EXPECT_EQ(block->size(), 2 /*const*/ + 4 /*alloca*/ + 4 /*store*/);
Expand Down Expand Up @@ -211,7 +214,8 @@ TEST(Scalarize, ScalarizeLocalLoad) {
// Without this GlobalStoreStmt, nothing survives irpass::die()
block->push_back<GlobalStoreStmt>(src_stmt, load_stmt);

irpass::scalarize(block.get());
irpass::scalarize(block.get(),
test_prog.prog()->this_thread_config().dynamic_index);
irpass::die(block.get());

EXPECT_EQ(block->size(), 4 /*alloca*/ + 4 /*load*/ + 4 /*store*/);
Expand Down

0 comments on commit 40d4b33

Please sign in to comment.