Skip to content

Commit

Permalink
[autodiff] Switch off parts of store forwarding optimization for auto…
Browse files Browse the repository at this point in the history
…diff (#5464)

* [autodiff] Switch off parts of store forwarding optimization for autodiff

* update opt args

* update opt args

* update opt args

* update opt args

* exclude cc backend

* no need to push zero to stack

* no need to push zero to stack

* exclude ConstantStmt when generating adjoint

* recover the initial zero for ad stack

* update
  • Loading branch information
erizmr committed Jul 21, 2022
1 parent 7f789d4 commit 523fea6
Show file tree
Hide file tree
Showing 12 changed files with 109 additions and 58 deletions.
61 changes: 24 additions & 37 deletions taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access) {
}
}

bool CFGNode::store_to_load_forwarding(bool after_lower_access) {
bool CFGNode::store_to_load_forwarding(bool after_lower_access,
bool autodiff_enabled) {
bool modified = false;
for (int i = begin_location; i < end_location; i++) {
// Store-to-load forwarding
Expand All @@ -274,22 +275,8 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) {
result = get_store_forwarding_data(alloca, i);
}
} else if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
if (!after_lower_access) {
bool store_forwarding = true;
if (auto global_ptr = global_load->src->cast<GlobalPtrStmt>()) {
TI_ASSERT(global_ptr->width() == 1);
auto &snodes = global_ptr->snodes;
if (snodes[0]->has_adjoint()) {
// Has adjoint SNode. Skip the store forwarding
// to keep the global load chain,
                // so that the grad of intermediate variable can be computed
// by GlobalLoadStmt
store_forwarding = false;
}
}
if (store_forwarding) {
result = get_store_forwarding_data(global_load->src, i);
}
if (!after_lower_access && !autodiff_enabled) {
result = get_store_forwarding_data(global_load->src, i);
}
}
if (result) {
Expand All @@ -311,31 +298,29 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) {
// Identical store elimination
if (auto local_store = stmt->cast<LocalStoreStmt>()) {
result = get_store_forwarding_data(local_store->dest, i);
if (result) {
if (result->is<AllocaStmt>()) {
// special case of alloca (initialized to 0)
if (auto stored_data = local_store->val->cast<ConstStmt>()) {
bool all_zero = true;
for (auto &val : stored_data->val.data) {
if (!val.equal_value(0)) {
all_zero = false;
break;
}
}
if (all_zero) {
erase(i); // This causes end_location--
i--; // to cancel i++ in the for loop
modified = true;
if (result && result->is<AllocaStmt>() && !autodiff_enabled) {
// special case of alloca (initialized to 0)
if (auto stored_data = local_store->val->cast<ConstStmt>()) {
bool all_zero = true;
for (auto &val : stored_data->val.data) {
if (!val.equal_value(0)) {
all_zero = false;
break;
}
}
} else {
// not alloca
if (irpass::analysis::same_value(result, local_store->val)) {
if (all_zero) {
erase(i); // This causes end_location--
i--; // to cancel i++ in the for loop
modified = true;
}
}
} else {
// not alloca
if (irpass::analysis::same_value(result, local_store->val)) {
erase(i); // This causes end_location--
i--; // to cancel i++ in the for loop
modified = true;
}
}
} else if (auto global_store = stmt->cast<GlobalStoreStmt>()) {
if (!after_lower_access) {
Expand Down Expand Up @@ -857,13 +842,15 @@ bool ControlFlowGraph::unreachable_code_elimination() {
return modified;
}

bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access) {
bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access,
bool autodiff_enabled) {
TI_AUTO_PROF;
reaching_definition_analysis(after_lower_access);
const int num_nodes = size();
bool modified = false;
for (int i = 0; i < num_nodes; i++) {
if (nodes[i]->store_to_load_forwarding(after_lower_access))
if (nodes[i]->store_to_load_forwarding(after_lower_access,
autodiff_enabled))
modified = true;
}
return modified;
Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/control_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class CFGNode {

// Analyses and optimizations inside a CFGNode.
void reaching_definition_analysis(bool after_lower_access);
bool store_to_load_forwarding(bool after_lower_access);
bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled);
void gather_loaded_snodes(std::unordered_set<SNode *> &snodes) const;
void live_variable_analysis(bool after_lower_access);
bool dead_store_elimination(bool after_lower_access);
Expand Down Expand Up @@ -145,7 +145,7 @@ class ControlFlowGraph {
/**
* Perform store-to-load forwarding and identical store elimination.
*/
bool store_to_load_forwarding(bool after_lower_access);
bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled);

/**
* Perform dead store elimination and identical load elimination.
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ bool simplify(IRNode *root, const CompileConfig &config);
bool cfg_optimization(
IRNode *root,
bool after_lower_access,
bool autodiff_enabled,
const std::optional<ControlFlowGraph::LiveVarAnalysisConfig>
&lva_config_opt = std::nullopt);
bool alg_simp(IRNode *root, const CompileConfig &config);
Expand Down
8 changes: 5 additions & 3 deletions taichi/program/ir_bank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) {
}

irpass::full_simplify(task_a, kernel->program->config,
{/*after_lower_access=*/false, kernel->program});
{/*after_lower_access=*/false,
/*autodiff_enabled*/ false, kernel->program});
// For now, re_id is necessary for the hash to be correct.
irpass::re_id(task_a);

Expand Down Expand Up @@ -208,8 +209,9 @@ std::pair<IRHandle, bool> IRBank::optimize_dse(
}
ControlFlowGraph::LiveVarAnalysisConfig lva_config;
lva_config.eliminable_snodes = {snodes.begin(), snodes.end()};
const bool modified = irpass::cfg_optimization(
new_ir.get(), /*after_lower_access=*/false, lva_config);
const bool modified =
irpass::cfg_optimization(new_ir.get(), /*after_lower_access=*/false,
/*autodiff_enabled*/ false, lva_config);
if (verbose) {
TI_INFO(" DSE: after CFG, modified={}", modified);
std::cout << std::flush;
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/auto_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ class MakeAdjoint : public ADTransform {
}

Stmt *adjoint(Stmt *stmt) {
if (!is_real(stmt->ret_type)) {
if (!is_real(stmt->ret_type) || stmt->is<ConstStmt>()) {
return constant(0);
}
if (adjoint_stmt.find(stmt) == adjoint_stmt.end()) {
Expand Down
3 changes: 2 additions & 1 deletion taichi/transforms/cfg_optimization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace irpass {
bool cfg_optimization(
IRNode *root,
bool after_lower_access,
bool autodiff_enabled,
const std::optional<ControlFlowGraph::LiveVarAnalysisConfig>
&lva_config_opt) {
TI_AUTO_PROF;
Expand All @@ -18,7 +19,7 @@ bool cfg_optimization(
while (true) {
bool modified = false;
cfg->simplify_graph();
if (cfg->store_to_load_forwarding(after_lower_access))
if (cfg->store_to_load_forwarding(after_lower_access, autodiff_enabled))
modified = true;
if (cfg->dead_store_elimination(after_lower_access, lva_config_opt))
modified = true;
Expand Down
30 changes: 21 additions & 9 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ void compile_to_offloads(IRNode *ir,
irpass::analysis::verify(ir);
}

irpass::full_simplify(ir, config, {false, kernel->program});
irpass::full_simplify(
ir, config,
{false, /*autodiff_enabled*/ autodiff_mode != AutodiffMode::kNone,
kernel->program});
print("Simplified I");
irpass::analysis::verify(ir);

Expand All @@ -89,9 +92,12 @@ void compile_to_offloads(IRNode *ir,
// Remove local atomics here so that we don't have to handle their gradients
irpass::demote_atomics(ir, config);

irpass::full_simplify(ir, config, {false, kernel->program});
irpass::full_simplify(ir, config,
{false, /*autodiff_enabled*/ true, kernel->program});
irpass::auto_diff(ir, config, autodiff_mode, ad_use_stack);
irpass::full_simplify(ir, config, {false, kernel->program});
    // TODO: Be careful with full_simplify when doing high-order autodiff
irpass::full_simplify(ir, config,
{false, /*autodiff_enabled*/ false, kernel->program});
print("Gradient");
irpass::analysis::verify(ir);
}
Expand All @@ -106,7 +112,8 @@ void compile_to_offloads(IRNode *ir,
print("Access flagged I");
irpass::analysis::verify(ir);

irpass::full_simplify(ir, config, {false, kernel->program});
irpass::full_simplify(ir, config,
{false, /*autodiff_enabled*/ false, kernel->program});
print("Simplified II");
irpass::analysis::verify(ir);

Expand All @@ -117,15 +124,16 @@ void compile_to_offloads(IRNode *ir,
// TODO: This pass may be redundant as cfg_optimization() is already called
// in full_simplify().
if (config.opt_level > 0 && config.cfg_optimization) {
irpass::cfg_optimization(ir, false);
irpass::cfg_optimization(ir, false, /*autodiff_enabled*/ false);
print("Optimized by CFG");
irpass::analysis::verify(ir);
}

irpass::flag_access(ir);
print("Access flagged II");

irpass::full_simplify(ir, config, {false, kernel->program});
irpass::full_simplify(ir, config,
{false, /*autodiff_enabled*/ false, kernel->program});
print("Simplified III");
irpass::analysis::verify(ir);
}
Expand Down Expand Up @@ -187,7 +195,8 @@ void offload_to_executable(IRNode *ir,
if (config.make_mesh_block_local && config.arch == Arch::cuda) {
irpass::make_mesh_block_local(ir, config, {kernel->get_name()});
print("Make mesh block local");
irpass::full_simplify(ir, config, {false, kernel->program});
irpass::full_simplify(
ir, config, {false, /*autodiff_enabled*/ false, kernel->program});
print("Simplified X");
}
}
Expand Down Expand Up @@ -235,7 +244,9 @@ void offload_to_executable(IRNode *ir,
irpass::demote_operations(ir, config);
print("Operations demoted");

irpass::full_simplify(ir, config, {lower_global_access, kernel->program});
irpass::full_simplify(
ir, config,
{lower_global_access, /*autodiff_enabled*/ false, kernel->program});
print("Simplified IV");

if (determine_ad_stack_size) {
Expand Down Expand Up @@ -311,7 +322,8 @@ void compile_function(IRNode *ir,
irpass::type_check(ir, config);
print("Typechecked");

irpass::full_simplify(ir, config, {false, func->program});
irpass::full_simplify(
ir, config, {false, autodiff_mode != AutodiffMode::kNone, func->program});
print("Simplified");
irpass::analysis::verify(ir);
}
Expand Down
3 changes: 2 additions & 1 deletion taichi/transforms/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,8 @@ void full_simplify(IRNode *root,
// not modified.
if (config.opt_level > 0 && (first_iteration || modified) &&
config.cfg_optimization &&
cfg_optimization(root, args.after_lower_access))
cfg_optimization(root, args.after_lower_access,
args.autodiff_enabled))
modified = true;
first_iteration = false;
if (!modified)
Expand Down
3 changes: 3 additions & 0 deletions taichi/transforms/simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ class FullSimplifyPass : public Pass {

struct Args {
bool after_lower_access;
    // Switch off some optimizations in store forwarding if there is an
    // autodiff pass after full_simplify
bool autodiff_enabled;
Program *program;
};
};
Expand Down
3 changes: 2 additions & 1 deletion tests/cpp/transforms/inlining_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ TEST_F(InliningTest, ArgLoadOfArgLoad) {
irpass::type_check(kernel_block, CompileConfig());

irpass::inlining(kernel_block, CompileConfig(), {});
irpass::full_simplify(kernel_block, CompileConfig(), {false, prog_.get()});
irpass::full_simplify(kernel_block, CompileConfig(),
{false, false, prog_.get()});

EXPECT_EQ(kernel_block->size(), 4);
EXPECT_TRUE(irpass::analysis::same_statements(func_block, kernel_block));
Expand Down
8 changes: 6 additions & 2 deletions tests/python/test_ad_for.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,9 @@ def test_inner_loops_local_variable():
assert x.grad[None] == 36.0


@test_utils.test(require=ti.extension.adstack, ad_stack_size=0)
@test_utils.test(require=ti.extension.adstack,
ad_stack_size=0,
exclude=[ti.cc])
def test_more_inner_loops_local_variable_adaptive_stack_size_tape():
x = ti.field(dtype=float, shape=(), needs_grad=True)
arr = ti.field(dtype=float, shape=(2), needs_grad=True)
Expand All @@ -590,7 +592,9 @@ def test_more_inner_loops_local_variable():
assert x.grad[None] == 36.0


@test_utils.test(require=ti.extension.adstack, ad_stack_size=32)
@test_utils.test(require=ti.extension.adstack,
ad_stack_size=32,
exclude=[ti.cc])
def test_more_inner_loops_local_variable_fixed_stack_size_tape():
x = ti.field(dtype=float, shape=(), needs_grad=True)
arr = ti.field(dtype=float, shape=(2), needs_grad=True)
Expand Down
41 changes: 40 additions & 1 deletion tests/python/test_ad_for_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def fib():


@test_utils.test()
def test_ad_fibonacci_index():
def test_ad_fibonacci_index_fwd():
N = 5
M = 10
a = ti.field(ti.f32, shape=M)
Expand Down Expand Up @@ -173,3 +173,42 @@ def fib():
is_fib = int(i in [1, 2, 3, 5, 8])
assert f.dual[i] == is_fib * N
assert b[i] == is_fib * N


@test_utils.test(exclude=[ti.cc])
def test_double_for_loops():
N = 5
a = ti.field(ti.f32, shape=N)
b = ti.field(ti.f32, shape=N)
c = ti.field(ti.i32, shape=N)
f = ti.field(ti.f32, shape=N)
ti.root.lazy_dual()

@ti.kernel
def double_for():
for i in range(N):
weight = 1.0
for j in range(c[i]):
weight *= a[i]
s = 0.0
for j in range(c[i] * 2):
s += weight + b[i]
f[i] = s

a.fill(2)
b.fill(1)

for i in range(N):
c[i] = i

with ti.ad.FwdMode(loss=f, parameters=a, seed=[1.0 for _ in range(N)]):
double_for()

for i in range(N):
assert f.dual[i] == 2 * i * i * 2**(i - 1)

with ti.ad.FwdMode(loss=f, parameters=b, seed=[1.0 for _ in range(N)]):
double_for()

for i in range(N):
assert f.dual[i] == 2 * i

0 comments on commit 523fea6

Please sign in to comment.