Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[opt] Simplify replace_statements and improve demote_dense_struct_fors #2335

Merged
merged 2 commits into from
May 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,25 +75,30 @@ bool constant_fold(IRNode *root,
const CompileConfig &config,
const ConstantFoldPass::Args &args);
void offload(IRNode *root, const CompileConfig &config);
bool transform_statements(
IRNode *root,
std::function<bool(Stmt *)> filter,
std::function<void(Stmt *, DelayedIRModifier *)> transformer);
/**
* @param root The IR root to be traversed.
* @param filter A function which tells if a statement need to be replaced.
* @param generator If a statement s need to be replaced, generate a new
* statement s1 with the argument s, insert s1 to s's place, and replace all
* usages of s with s1.
* @param generator If a statement |s| need to be replaced, generate a new
* statement |s1| with the argument |s|, insert |s1| to where |s| is defined,
* remove |s|'s definition, and replace all usages of |s| with |s1|.
* @return Whether the IR is modified.
*/
void replace_statements_with(
bool replace_and_insert_statements(
IRNode *root,
std::function<bool(Stmt *)> filter,
std::function<std::unique_ptr<Stmt>(Stmt *)> generator);
/**
* @param generator If a statement s need to be replaced, find the existing
* statement s1 with the argument s, and replace all usages of s with s1.
* @return Whether the IR is modified.
* @param finder If a statement |s| need to be replaced, find the existing
* statement |s1| with the argument |s|, remove |s|'s definition, and replace
* all usages of |s| with |s1|.
*/
bool replace_statements_with(IRNode *root,
std::function<bool(Stmt *)> filter,
std::function<Stmt *(Stmt *)> generator);
bool replace_statements(IRNode *root,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following the above comment, how is |s|'s defining stmt handled in this case?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

|s|'s defining stmt is erased here.

std::function<bool(Stmt *)> filter,
std::function<Stmt *(Stmt *)> finder);
void demote_dense_struct_fors(IRNode *root);
bool demote_atomics(IRNode *root, const CompileConfig &config);
void reverse_segments(IRNode *root); // for autograd
Expand Down
29 changes: 15 additions & 14 deletions taichi/transforms/demote_dense_struct_fors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,21 +86,22 @@ void convert_to_range_for(OffloadedStmt *offloaded) {
}
}

for (int i = 0; i < num_loop_vars; i++) {
// TODO: Use only one (instead num_loop_vars) invocation(s) of
// replace_statements_with
irpass::replace_statements_with(
body.get(),
[&](Stmt *s) {
if (auto loop_index = s->cast<LoopIndexStmt>()) {
return loop_index->loop == offloaded &&
loop_index->index ==
snodes.back()->physical_index_position[i];
}
irpass::replace_statements(
body.get(), /*filter=*/
[&](Stmt *s) {
if (auto loop_index = s->cast<LoopIndexStmt>()) {
return loop_index->loop == offloaded;
} else {
return false;
},
[&](Stmt *) { return new_loop_vars[i]; });
}
}
},
/*finder=*/
[&](Stmt *s) {
auto index = std::find(physical_indices.begin(), physical_indices.end(),
s->as<LoopIndexStmt>()->index);
TI_ASSERT(index != physical_indices.end());
return new_loop_vars[index - physical_indices.begin()];
});

if (has_test) {
// Create an If statement
Expand Down
8 changes: 3 additions & 5 deletions taichi/transforms/inlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,10 @@ class Inliner : public BasicStmtVisitor {
TI_ASSERT(func->rets.size() <= 1);
auto inlined_ir = irpass::analysis::clone(func->ir.get());
if (!func->args.empty()) {
// TODO: Make sure that if stmt->args is an ArgLoadStmt,
// it will not be replaced again here
irpass::replace_statements_with(
irpass::replace_statements(
inlined_ir.get(),
/*filter=*/[&](Stmt *s) { return s->is<ArgLoadStmt>(); },
/*generator=*/
/*finder=*/
[&](Stmt *s) { return stmt->args[s->as<ArgLoadStmt>()->arg_id]; });
}
if (func->rets.empty()) {
Expand All @@ -46,7 +44,7 @@ class Inliner : public BasicStmtVisitor {
// Use a local variable to store the return value
auto *return_address = inlined_ir->as<Block>()->insert(
Stmt::make<AllocaStmt>(func->rets[0].dt), /*location=*/0);
irpass::replace_statements_with(
irpass::replace_and_insert_statements(
inlined_ir.get(),
/*filter=*/[&](Stmt *s) { return s->is<KernelReturnStmt>(); },
/*generator=*/
Expand Down
30 changes: 30 additions & 0 deletions taichi/transforms/replace_statements.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include "taichi/ir/transforms.h"

TLANG_NAMESPACE_BEGIN

namespace irpass {

bool replace_and_insert_statements(
IRNode *root,
std::function<bool(Stmt *)> filter,
std::function<std::unique_ptr<Stmt>(Stmt *)> generator) {
return transform_statements(root, std::move(filter),
[&](Stmt *stmt, DelayedIRModifier *modifier) {
modifier->replace_with(stmt, generator(stmt));
});
}

bool replace_statements(IRNode *root,
std::function<bool(Stmt *)> filter,
std::function<Stmt *(Stmt *)> finder) {
return transform_statements(
root, std::move(filter), [&](Stmt *stmt, DelayedIRModifier *modifier) {
auto existing_new_stmt = finder(stmt);
irpass::replace_all_usages_with(root, stmt, existing_new_stmt);
modifier->erase(stmt);
});
}

} // namespace irpass

TLANG_NAMESPACE_END
155 changes: 0 additions & 155 deletions taichi/transforms/statement_replace.cpp

This file was deleted.

61 changes: 61 additions & 0 deletions taichi/transforms/transform_statements.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"

TLANG_NAMESPACE_BEGIN

// Transform each filtered statement
class StatementsTransformer : public BasicStmtVisitor {
public:
using BasicStmtVisitor::visit;

StatementsTransformer(
std::function<bool(Stmt *)> filter,
std::function<void(Stmt *, DelayedIRModifier *)> transformer)
: filter_(std::move(filter)), transformer_(std::move(transformer)) {
allow_undefined_visitor = true;
invoke_default_visitor = true;
}

void maybe_transform(Stmt *stmt) {
if (filter_(stmt)) {
transformer_(stmt, &modifier_);
}
}

void preprocess_container_stmt(Stmt *stmt) override {
maybe_transform(stmt);
}

void visit(Stmt *stmt) override {
maybe_transform(stmt);
}

static bool run(IRNode *root,
std::function<bool(Stmt *)> filter,
std::function<void(Stmt *, DelayedIRModifier *)> replacer) {
StatementsTransformer transformer(std::move(filter), std::move(replacer));
root->accept(&transformer);
return transformer.modifier_.modify_ir();
}

private:
std::function<bool(Stmt *)> filter_;
std::function<void(Stmt *, DelayedIRModifier *)> transformer_;
DelayedIRModifier modifier_;
};

namespace irpass {

bool transform_statements(
IRNode *root,
std::function<bool(Stmt *)> filter,
std::function<void(Stmt *, DelayedIRModifier *)> transformer) {
return StatementsTransformer::run(root, std::move(filter),
std::move(transformer));
}

} // namespace irpass

TLANG_NAMESPACE_END