Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Opt] [bug] Better aliasing analysis for dead store elimination #1432

Merged
merged 2 commits into from Jul 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions taichi/analysis/data_source_analysis.cpp
Expand Up @@ -26,6 +26,12 @@ std::vector<Stmt *> get_load_pointers(Stmt *load_stmt) {
} else if (auto stack_acc_adj = load_stmt->cast<StackAccAdjointStmt>()) {
// This statement loads and stores the adjoint data.
return std::vector<Stmt *>(1, stack_acc_adj->stack);
} else if (auto stack_push = load_stmt->cast<StackPushStmt>()) {
// This is to make dead store elimination not eliminate consecutive pushes.
return std::vector<Stmt *>(1, stack_push->stack);
} else if (auto stack_pop = load_stmt->cast<StackPopStmt>()) {
// This is to make dead store elimination not eliminate consecutive pops.
return std::vector<Stmt *>(1, stack_pop->stack);
} else {
return std::vector<Stmt *>();
}
Expand Down
41 changes: 32 additions & 9 deletions taichi/ir/control_flow_graph.cpp
Expand Up @@ -82,8 +82,27 @@ bool CFGNode::contain_variable(const std::unordered_set<Stmt *> &var_set,
return var_set.find(var) != var_set.end();
} else {
// TODO: How to optimize this?
if (var_set.find(var) != var_set.end())
return true;
for (auto set_var : var_set) {
if (irpass::analysis::same_statements(var, set_var)) {
if (definitely_same_address(var, set_var)) {
return true;
}
}
return false;
}
}

bool CFGNode::may_contain_variable(const std::unordered_set<Stmt *> &var_set,
Stmt *var) {
if (var->is<AllocaStmt>() || var->is<StackAllocaStmt>()) {
return var_set.find(var) != var_set.end();
} else {
// TODO: How to optimize this?
if (var_set.find(var) != var_set.end())
return true;
for (auto set_var : var_set) {
if (maybe_same_address(var, set_var)) {
return true;
}
}
Expand Down Expand Up @@ -290,16 +309,18 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
store_ptr = stack_push->stack;
} else if (auto stack_acc_adj = stmt->cast<StackAccAdjointStmt>()) {
store_ptr = stack_acc_adj->stack;
} else if (stmt->is<StackAllocaStmt>()) {
store_ptr = stmt;
}
if (store_ptr) {
if (!after_lower_access ||
(store_ptr->is<AllocaStmt>() || store_ptr->is<StackAllocaStmt>())) {
// After lower_access, we only analyze local variables and stacks.
// Do not eliminate AllocaStmt here.
if (!stmt->is<AllocaStmt>() &&
// Do not eliminate AllocaStmt and StackAllocaStmt here.
if (!stmt->is<AllocaStmt>() && !stmt->is<StackAllocaStmt>() &&
!may_contain_variable(live_in_this_node, store_ptr) &&
(contain_variable(killed_in_this_node, store_ptr) ||
(!contain_variable(live_out, store_ptr) &&
!contain_variable(live_in_this_node, store_ptr)))) {
!may_contain_variable(live_out, store_ptr))) {
// Neither used in other nodes nor used in this node.
if (auto atomic = stmt->cast<AtomicOpStmt>()) {
// Weaken the atomic operation to a load.
Expand All @@ -309,7 +330,6 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
local_load->ret_type = atomic->ret_type;
replace_with(i, std::move(local_load), true);
// Notice that we have a load here.
killed_in_this_node.erase(atomic->dest);
live_in_this_node.insert(atomic->dest);
modified = true;
continue;
Expand All @@ -322,7 +342,6 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
global_load->ret_type = atomic->ret_type;
replace_with(i, std::move(global_load), true);
// Notice that we have a load here.
killed_in_this_node.erase(atomic->dest);
live_in_this_node.insert(atomic->dest);
modified = true;
continue;
Expand All @@ -335,7 +354,12 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
} else {
// A non-eliminated store.
killed_in_this_node.insert(store_ptr);
live_in_this_node.erase(store_ptr);
auto old_live_in_this_node = std::move(live_in_this_node);
live_in_this_node.clear();
for (auto &var : old_live_in_this_node) {
if (!definitely_same_address(store_ptr, var))
live_in_this_node.insert(var);
}
}
}
}
Expand All @@ -344,7 +368,6 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
if (!after_lower_access ||
(load_ptr->is<AllocaStmt>() || load_ptr->is<StackAllocaStmt>())) {
// After lower_access, we only analyze local variables and stacks.
killed_in_this_node.erase(load_ptr);
live_in_this_node.insert(load_ptr);
}
}
Expand Down
2 changes: 2 additions & 0 deletions taichi/ir/control_flow_graph.h
Expand Up @@ -53,6 +53,8 @@ class CFGNode {

static bool contain_variable(const std::unordered_set<Stmt *> &var_set,
Stmt *var);
static bool may_contain_variable(const std::unordered_set<Stmt *> &var_set,
Stmt *var);
void reaching_definition_analysis(bool after_lower_access);
bool reach_kill_variable(Stmt *var) const;
Stmt *get_store_forwarding_data(Stmt *var, int position) const;
Expand Down
62 changes: 62 additions & 0 deletions taichi/ir/ir.cpp
Expand Up @@ -24,6 +24,51 @@ CompileConfig &IRNode::get_config() const {
return get_kernel()->program.config;
}

bool definitely_same_address(Stmt *var1, Stmt *var2) {
  // Returns true only when the two statements must denote the same address.
  // Returns false as soon as they could denote different addresses.

  // The very same statement (this also covers two null pointers).
  if (var1 == var2) {
    return true;
  }
  // From here on the pointers differ, so a null on either side is a mismatch.
  if (var1 == nullptr || var2 == nullptr) {
    return false;
  }

  // Two distinct statements where at least one is an alloca or a stack
  // alloca are reported as not the same address: an alloca only matches
  // itself, and that case was handled by the identity check above.
  if (var1->is<AllocaStmt>() || var2->is<AllocaStmt>() ||
      var1->is<StackAllocaStmt>() || var2->is<StackAllocaStmt>()) {
    return false;
  }

  // TODO(xumingkuan): Put GlobalTemporaryStmt, ThreadLocalPtrStmt and
  //  BlockLocalPtrStmt into GlobalPtrStmt.

  // Global temporaries: same address iff both are global temporaries with
  // equal offsets. A global temporary paired with any other kind is a
  // mismatch.
  {
    auto *lhs_temp = var1->cast<GlobalTemporaryStmt>();
    auto *rhs_temp = var2->cast<GlobalTemporaryStmt>();
    if (lhs_temp || rhs_temp) {
      return lhs_temp && rhs_temp && lhs_temp->offset == rhs_temp->offset;
    }
  }

  // Thread-local pointers: same rule, compared by offset value.
  {
    auto *lhs_tls = var1->cast<ThreadLocalPtrStmt>();
    auto *rhs_tls = var2->cast<ThreadLocalPtrStmt>();
    if (lhs_tls || rhs_tls) {
      return lhs_tls && rhs_tls && lhs_tls->offset == rhs_tls->offset;
    }
  }

  // Block-local pointers: the offsets are handed to same_statements()
  // rather than compared with ==.
  {
    auto *lhs_bls = var1->cast<BlockLocalPtrStmt>();
    auto *rhs_bls = var2->cast<BlockLocalPtrStmt>();
    if (lhs_bls || rhs_bls) {
      return lhs_bls && rhs_bls &&
             irpass::analysis::same_statements(lhs_bls->offset,
                                               rhs_bls->offset);
    }
  }

  // Every remaining kind: defer to the same_statements() analysis on the
  // two statements themselves.
  return irpass::analysis::same_statements(var1, var2);
}

bool maybe_same_address(Stmt *var1, Stmt *var2) {
// Return true when two statements might be the same address;
// false when two statements cannot be the same address.
Expand All @@ -36,6 +81,8 @@ bool maybe_same_address(Stmt *var1, Stmt *var2) {
return false;
if (var1->is<AllocaStmt>() || var2->is<AllocaStmt>())
return false;
if (var1->is<StackAllocaStmt>() || var2->is<StackAllocaStmt>())
return false;

// If both statements are global temps, they have the same address iff they
// have the same offset. If only one of them is a global temp, they can never
Expand All @@ -47,6 +94,21 @@ bool maybe_same_address(Stmt *var1, Stmt *var2) {
var2->as<GlobalTemporaryStmt>()->offset;
}

if (var1->is<ThreadLocalPtrStmt>() || var2->is<ThreadLocalPtrStmt>()) {
if (!var1->is<ThreadLocalPtrStmt>() || !var2->is<ThreadLocalPtrStmt>())
return false;
return var1->as<ThreadLocalPtrStmt>()->offset ==
var2->as<ThreadLocalPtrStmt>()->offset;
}

if (var1->is<BlockLocalPtrStmt>() || var2->is<BlockLocalPtrStmt>()) {
if (!var1->is<BlockLocalPtrStmt>() || !var2->is<BlockLocalPtrStmt>())
return false;
return irpass::analysis::same_statements(
var1->as<BlockLocalPtrStmt>()->offset,
var2->as<BlockLocalPtrStmt>()->offset);
}

// If both statements are GlobalPtrStmts or GetChStmts, we can check by
// SNode::id.
TI_ASSERT(var1->width() == 1);
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/ir.h
Expand Up @@ -33,6 +33,7 @@ using ScratchPadOptions = std::vector<std::pair<int, SNode *>>;

IRBuilder &current_ast_builder();

bool definitely_same_address(Stmt *var1, Stmt *var2);
bool maybe_same_address(Stmt *var1, Stmt *var2);

struct VectorType {
Expand Down
12 changes: 6 additions & 6 deletions tests/python/test_ad_for.py
Expand Up @@ -10,7 +10,7 @@ def test_ad_sum():
p = ti.var(ti.f32, shape=N, needs_grad=True)

@ti.kernel
def comptue_sum():
def compute_sum():
for i in range(N):
ret = 1.0
for j in range(b[i]):
Expand All @@ -21,13 +21,13 @@ def comptue_sum():
a[i] = 3
b[i] = i

comptue_sum()
compute_sum()

for i in range(N):
assert p[i] == 3 * b[i] + 1
p.grad[i] = 1

comptue_sum.grad()
compute_sum.grad()

for i in range(N):
assert a.grad[i] == b[i]
Expand All @@ -43,7 +43,7 @@ def test_ad_sum_local_atomic():
p = ti.var(ti.f32, shape=N, needs_grad=True)

@ti.kernel
def comptue_sum():
def compute_sum():
for i in range(N):
ret = 1.0
for j in range(b[i]):
Expand All @@ -54,13 +54,13 @@ def comptue_sum():
a[i] = 3
b[i] = i

comptue_sum()
compute_sum()

for i in range(N):
assert p[i] == 3 * b[i] + 1
p.grad[i] = 1

comptue_sum.grad()
compute_sum.grad()

for i in range(N):
assert a.grad[i] == b[i]
Expand Down