Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Opt] Dive into container statements to find local loads/stores for optimization, and optimize loads of new allocas to 0 #662

Merged
merged 12 commits into from
Mar 27, 2020
267 changes: 216 additions & 51 deletions taichi/transforms/simplify.cpp
@@ -1,9 +1,132 @@
#include <set>
#include <unordered_set>
#include <utility>
#include "taichi/ir/ir.h"

TLANG_NAMESPACE_BEGIN

// Find if there is a load of, or an atomic op on, a local variable under a given root
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me that LocalLoadSearcher is just finding loads, since LocalStore is not mentioned here. I'm confused...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes, LocalLoadSearcher finds loads anywhere under the root, where the root is actually a statement after the LocalStore, passed in as an argument. I'd better change the comment to simply "Find if there is a load".

// Visitor that reports whether |var| is read anywhere under the root it is
// run on. Two statement kinds count as a read:
//   * a LocalLoadStmt that has |var| as one of its sources;
//   * an AtomicOpStmt whose destination is |var|.
// Use the static run() helper rather than driving the visitor by hand.
class LocalLoadSearcher : public BasicStmtVisitor {
 private:
  Stmt *var;    // the local variable being searched for
  bool result;  // becomes true once a matching statement has been visited

 public:
  using BasicStmtVisitor::visit;

  explicit LocalLoadSearcher(Stmt *var) : var(var), result(false) {
    // NOTE(review): presumably these flags let statement kinds without a
    // dedicated visit() overload fall through to the default visitor instead
    // of raising an error -- confirm against the visitor base class.
    allow_undefined_visitor = true;
    invoke_default_visitor = true;
  }

  // A load counts when |var| is among its sources.
  void visit(LocalLoadStmt *stmt) override {
    if (stmt->has_source(var)) {
      result = true;
    }
  }

  // An atomic op on |var| is treated like a load: a store that precedes it
  // cannot be eliminated. The two situations the callers care about:
  //
  // current store: $d
  // $a = alloca
  // $b : local store [$a <- v1]  <-- prev lstore |bstmt_|
  // $c = atomic add($a, v2)      <-- cannot eliminate $b
  // $d : local store [$a <- v3]
  //
  // current store: $b
  // $a = alloca
  // $b : local store [$a <- v1]
  // $c = atomic add($a, v2)      <-- cannot eliminate $b
  void visit(AtomicOpStmt *stmt) override {
    if (stmt->dest == var) {
      result = true;
    }
  }

  // Runs a fresh searcher over |root| and returns true iff a load of (or an
  // atomic op on) |var| was visited.
  static bool run(IRNode *root, Stmt *var) {
    LocalLoadSearcher searcher(var);
    root->accept(&searcher);
    return searcher.result;
  }
};

// Find if there is a store to, or an atomic op on, any of the given local variables under a given root
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, isn't this simply Find if there is a store?

class LocalStoreSearcher : public BasicStmtVisitor {
private:
const std::vector<Stmt *> &vars;
bool result;

public:
using BasicStmtVisitor::visit;

explicit LocalStoreSearcher(const std::vector<Stmt *> &vars)
: vars(vars), result(false) {
allow_undefined_visitor = true;
invoke_default_visitor = true;
}

void visit(LocalStoreStmt *stmt) override {
for (auto var : vars) {
if (stmt->ptr == var) {
result = true;
break;
}
}
}

void visit(AtomicOpStmt *stmt) override {
for (auto var : vars) {
if (stmt->dest == var) {
result = true;
break;
}
}
}

static bool run(IRNode *root, const std::vector<Stmt *> &vars) {
LocalStoreSearcher searcher(vars);
root->accept(&searcher);
return searcher.result;
}
};

// Find the **last** store to a local variable under a given root (the alloca itself also counts; an atomic op on the variable resets the result)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, how does this ensure preceding?

class LocalStoreForwarder : public BasicStmtVisitor {
private:
Stmt *var;
Stmt *result;

public:
using BasicStmtVisitor::visit;

explicit LocalStoreForwarder(Stmt *var) : var(var), result(nullptr) {
allow_undefined_visitor = true;
invoke_default_visitor = true;
}

void visit(LocalStoreStmt *stmt) override {
if (stmt->ptr == var) {
result = stmt;
}
}

void visit(AllocaStmt *stmt) override {
if (stmt == var) {
result = stmt;
}
}

void visit(AtomicOpStmt *stmt) override {
if (stmt->dest == var) {
result = nullptr;
}
}

static Stmt *run(IRNode *root, Stmt *var) {
LocalStoreForwarder searcher(var);
root->accept(&searcher);
return searcher.result;
}
};

// Common subexpression elimination, store forwarding, useless local store
// elimination; Simplify if statements into conditional stores.
class BasicBlockSimplify : public IRVisitor {
Expand Down Expand Up @@ -260,13 +383,20 @@ class BasicBlockSimplify : public IRVisitor {
// no store to the var?
bool has_related_store = false;
for (int j = i + 1; j < current_stmt_id; j++) {
if (block->statements[j]
->is_container_statement()) { // no if, while, etc..
has_related_store = true;
break;
if (!advanced_optimization) {
if (block->statements[j]
->is_container_statement()) { // no if, while, etc..
has_related_store = true;
break;
}
if (modifies_local(block->statements[j].get(), vars)) {
has_related_store = true;
}
continue;
}
if (modifies_local(block->statements[j].get(), vars)) {
if (LocalStoreSearcher::run(block->statements[j].get(), vars)) {
has_related_store = true;
break;
}
}
if (!has_related_store) {
Expand All @@ -290,30 +420,50 @@ class BasicBlockSimplify : public IRVisitor {
if (regular) {
// Check all previous statements in the current block before the local
// load
auto block = stmt->parent;
Stmt *containing_statement = stmt;
auto stmt_id = block->locate(containing_statement);
TI_ASSERT(stmt_id != -1);
for (int i = stmt_id - 1; i >= 0; i--) {
auto &bstmt = block->statements[i];
// Find a previous store
if (auto s = bstmt->cast<AtomicOpStmt>()) {
if (s->dest == alloca) {
if (!advanced_optimization) {
auto &bstmt = block->statements[i];
// Find a previous store
if (auto s = bstmt->cast<AtomicOpStmt>()) {
if (s->dest == alloca) {
break;
}
}
if (bstmt->is<LocalStoreStmt>()) {
auto bstmt_ = bstmt->as<LocalStoreStmt>();
// Same alloca
if (bstmt_->ptr == alloca) {
// Forward to the first local store only
stmt->replace_with(bstmt_->data);
stmt->parent->erase(current_stmt_id);
throw IRModified();
}
} else if (bstmt->is_container_statement()) {
// assume this container may modify the local var
break;
}
continue;
}
if (bstmt->is<LocalStoreStmt>()) {
auto bstmt_ = bstmt->as<LocalStoreStmt>();
// Same alloca
if (bstmt_->ptr == alloca) {
auto bstmt =
LocalStoreForwarder::run(block->statements[i].get(), alloca);
if (bstmt != nullptr) {
if (bstmt->is<LocalStoreStmt>()) {
// Forward to the first local store only
stmt->replace_with(bstmt_->data);
stmt->replace_with(bstmt->as<LocalStoreStmt>()->data);
stmt->parent->erase(current_stmt_id);
throw IRModified();
} else {
TI_ASSERT(bstmt->is<AllocaStmt>());
auto zero = stmt->insert_after_me(Stmt::make<ConstStmt>(
LaneAttribute<TypedConstant>(bstmt->ret_type.data_type)));
zero->repeat(stmt->width());
stmt->replace_with(zero);
stmt->parent->erase(current_stmt_id);
throw IRModified();
}
} else if (bstmt->is_container_statement()) {
// assume this container may modify the local var
break;
}
}
// Note: simply checking all statements before stmt is not sufficient
Expand All @@ -339,24 +489,32 @@ class BasicBlockSimplify : public IRVisitor {
if (same) {
bool has_load = false;
for (int j = i + 1; j < current_stmt_id; j++) {
if (block->statements[j]
->is_container_statement()) { // no if, while, etc..
has_load = true;
break;
if (!advanced_optimization) {
if (block->statements[j]
->is_container_statement()) { // no if, while, etc..
has_load = true;
break;
}
if (block->statements[j]->is<LocalLoadStmt>() &&
block->statements[j]->as<LocalLoadStmt>()->has_source(
stmt->ptr)) {
has_load = true;
}
if (block->statements[j]->is<AtomicOpStmt>() &&
(block->statements[j]->as<AtomicOpStmt>()->dest ==
stmt->ptr)) {
// $a = alloca
// $b : local store [$a <- v1] <-- prev lstore |bstmt_|
// $c = atomic add($a, v2) <-- cannot eliminate $b
// $d : local store [$a <- v3]
has_load = true;
}
continue;
}
if (block->statements[j]->is<LocalLoadStmt>() &&
block->statements[j]->as<LocalLoadStmt>()->has_source(
stmt->ptr)) {
has_load = true;
}
if (block->statements[j]->is<AtomicOpStmt>() &&
(block->statements[j]->as<AtomicOpStmt>()->dest ==
stmt->ptr)) {
// $a = alloca
// $b : local store [$a <- v1] <-- prev lstore |bstmt_|
// $c = atomic add($a, v2) <-- cannot eliminate $b
// $d : local store [$a <- v3]
if (LocalLoadSearcher::run(block->statements[j].get(),
stmt->ptr)) {
has_load = true;
break;
}
}
if (!has_load) {
Expand All @@ -374,27 +532,34 @@ class BasicBlockSimplify : public IRVisitor {
bool has_related = false;
for (int i = current_stmt_id + 1; i < (int)block->statements.size();
i++) {
auto &bstmt = block->statements[i];
if (bstmt->is_container_statement()) {
has_related = true;
break;
}
if (bstmt->is<LocalLoadStmt>()) {
auto bstmt_ = bstmt->as<LocalLoadStmt>();
if (bstmt_->has_source(stmt->ptr)) {
if (!advanced_optimization) {
auto &bstmt = block->statements[i];
if (bstmt->is_container_statement()) {
has_related = true;
break;
}
}
if (bstmt->is<AtomicOpStmt>()) {
// $a = alloca
// $b : local store [$a <- v1]
// $c = atomic add($a, v2) <-- cannot eliminate $b
auto bstmt_ = bstmt->as<AtomicOpStmt>();
if (bstmt_->dest == stmt->ptr) {
has_related = true;
break;
if (bstmt->is<LocalLoadStmt>()) {
auto bstmt_ = bstmt->as<LocalLoadStmt>();
if (bstmt_->has_source(stmt->ptr)) {
has_related = true;
break;
}
}
if (bstmt->is<AtomicOpStmt>()) {
// $a = alloca
// $b : local store [$a <- v1]
// $c = atomic add($a, v2) <-- cannot eliminate $b
auto bstmt_ = bstmt->as<AtomicOpStmt>();
if (bstmt_->dest == stmt->ptr) {
has_related = true;
break;
}
}
continue;
}
if (LocalLoadSearcher::run(block->statements[i].get(), stmt->ptr)) {
has_related = true;
break;
}
}
if (!has_related) {
Expand Down