[autodiff] Switch off parts of store forwarding optimization for autodiff #5464

Merged · 11 commits · Jul 21, 2022
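In effect, this PR threads an `autodiff_enabled` flag through `full_simplify` and `cfg_optimization` so that store-to-load forwarding is switched off whenever an autodiff pass will still run, instead of relying on the old per-SNode `has_adjoint()` check. A minimal Python sketch of the failure mode being avoided (field names, `ti.cpu`, and the `ti.ad.Tape` call are illustrative, not taken from this PR; `ti.Tape` on older releases):

```python
import taichi as ti

ti.init(arch=ti.cpu)

x = ti.field(dtype=ti.f32, shape=(), needs_grad=True)
y = ti.field(dtype=ti.f32, shape=(), needs_grad=True)  # intermediate field with an adjoint
loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)

@ti.kernel
def forward():
    y[None] = x[None] * x[None]  # GlobalStoreStmt to y
    loss[None] = y[None] * 3.0   # GlobalLoadStmt of y; if forwarding replaced this
                                 # load with the stored value, the load chain that
                                 # autodiff uses to compute y's grad would vanish

x[None] = 2.0
with ti.ad.Tape(loss=loss):
    forward()
print(x.grad[None])  # d(3*x^2)/dx = 6*x = 12.0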
61 changes: 24 additions & 37 deletions taichi/ir/control_flow_graph.cpp
@@ -255,7 +255,8 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access) {
}
}

bool CFGNode::store_to_load_forwarding(bool after_lower_access) {
bool CFGNode::store_to_load_forwarding(bool after_lower_access,
bool autodiff_enabled) {
bool modified = false;
for (int i = begin_location; i < end_location; i++) {
// Store-to-load forwarding
@@ -274,22 +275,8 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) {
result = get_store_forwarding_data(alloca, i);
}
} else if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
if (!after_lower_access) {
bool store_forwarding = true;
if (auto global_ptr = global_load->src->cast<GlobalPtrStmt>()) {
TI_ASSERT(global_ptr->width() == 1);
auto &snodes = global_ptr->snodes;
if (snodes[0]->has_adjoint()) {
// Has an adjoint SNode. Skip the store forwarding
// to keep the global load chain,
// so that the grad of the intermediate variable can be
// computed by GlobalLoadStmt
store_forwarding = false;
}
}
if (store_forwarding) {
result = get_store_forwarding_data(global_load->src, i);
}
if (!after_lower_access && !autodiff_enabled) {
result = get_store_forwarding_data(global_load->src, i);
}
}
if (result) {
@@ -311,31 +298,29 @@
// Identical store elimination
if (auto local_store = stmt->cast<LocalStoreStmt>()) {
result = get_store_forwarding_data(local_store->dest, i);
if (result) {
if (result->is<AllocaStmt>()) {
// special case of alloca (initialized to 0)
if (auto stored_data = local_store->val->cast<ConstStmt>()) {
bool all_zero = true;
for (auto &val : stored_data->val.data) {
if (!val.equal_value(0)) {
all_zero = false;
break;
}
}
if (all_zero) {
erase(i); // This causes end_location--
i--; // to cancel i++ in the for loop
modified = true;
if (result && result->is<AllocaStmt>() && !autodiff_enabled) {
// special case of alloca (initialized to 0)
if (auto stored_data = local_store->val->cast<ConstStmt>()) {
bool all_zero = true;
for (auto &val : stored_data->val.data) {
if (!val.equal_value(0)) {
all_zero = false;
break;
}
}
} else {
// not alloca
if (irpass::analysis::same_value(result, local_store->val)) {
if (all_zero) {
erase(i); // This causes end_location--
i--; // to cancel i++ in the for loop
modified = true;
}
}
} else {
// not alloca
if (irpass::analysis::same_value(result, local_store->val)) {
erase(i); // This causes end_location--
i--; // to cancel i++ in the for loop
modified = true;
}
}
} else if (auto global_store = stmt->cast<GlobalStoreStmt>()) {
if (!after_lower_access) {
@@ -857,13 +842,15 @@ bool ControlFlowGraph::unreachable_code_elimination() {
return modified;
}

bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access) {
bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access,
bool autodiff_enabled) {
TI_AUTO_PROF;
reaching_definition_analysis(after_lower_access);
const int num_nodes = size();
bool modified = false;
for (int i = 0; i < num_nodes; i++) {
if (nodes[i]->store_to_load_forwarding(after_lower_access))
if (nodes[i]->store_to_load_forwarding(after_lower_access,
autodiff_enabled))
modified = true;
}
return modified;
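The identical-store-elimination hunk above has the same flavor: erasing a `LocalStoreStmt` of constant zero into a fresh alloca is now also gated on `!autodiff_enabled`, presumably so the adjoint pass still sees the explicit initialization. A hedged sketch of the kernel shape this protects, mirroring the inner-loop tests changed below (names and values are illustrative):

```python
import taichi as ti

ti.init(arch=ti.cpu)

N = 5
a = ti.field(ti.f32, shape=N, needs_grad=True)
c = ti.field(ti.i32, shape=N)
f = ti.field(ti.f32, shape=N, needs_grad=True)

@ti.kernel
def accumulate():
    for i in range(N):
        s = 0.0  # LocalStoreStmt of constant 0 into a fresh alloca; with
                 # autodiff_enabled it is no longer folded into the alloca's
                 # implicit zero-init, so it survives for the adjoint pass
        for j in range(c[i]):
            s += a[i]
        f[i] = s

c.fill(3)
accumulate()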
4 changes: 2 additions & 2 deletions taichi/ir/control_flow_graph.h
@@ -76,7 +76,7 @@ class CFGNode {

// Analyses and optimizations inside a CFGNode.
void reaching_definition_analysis(bool after_lower_access);
bool store_to_load_forwarding(bool after_lower_access);
bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled);
void gather_loaded_snodes(std::unordered_set<SNode *> &snodes) const;
void live_variable_analysis(bool after_lower_access);
bool dead_store_elimination(bool after_lower_access);
@@ -145,7 +145,7 @@ class ControlFlowGraph {
/**
* Perform store-to-load forwarding and identical store elimination.
*/
bool store_to_load_forwarding(bool after_lower_access);
bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled);

/**
* Perform dead store elimination and identical load elimination.
1 change: 1 addition & 0 deletions taichi/ir/transforms.h
@@ -34,6 +34,7 @@ bool simplify(IRNode *root, const CompileConfig &config);
bool cfg_optimization(
IRNode *root,
bool after_lower_access,
bool autodiff_enabled,
const std::optional<ControlFlowGraph::LiveVarAnalysisConfig>
&lva_config_opt = std::nullopt);
bool alg_simp(IRNode *root, const CompileConfig &config);
8 changes: 5 additions & 3 deletions taichi/program/ir_bank.cpp
@@ -114,7 +114,8 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) {
}

irpass::full_simplify(task_a, kernel->program->config,
{/*after_lower_access=*/false, kernel->program});
{/*after_lower_access=*/false,
/*autodiff_enabled*/ false, kernel->program});
// For now, re_id is necessary for the hash to be correct.
irpass::re_id(task_a);

@@ -208,8 +209,9 @@ std::pair<IRHandle, bool> IRBank::optimize_dse(
}
ControlFlowGraph::LiveVarAnalysisConfig lva_config;
lva_config.eliminable_snodes = {snodes.begin(), snodes.end()};
const bool modified = irpass::cfg_optimization(
new_ir.get(), /*after_lower_access=*/false, lva_config);
const bool modified =
irpass::cfg_optimization(new_ir.get(), /*after_lower_access=*/false,
/*autodiff_enabled*/ false, lva_config);
if (verbose) {
TI_INFO(" DSE: after CFG, modified={}", modified);
std::cout << std::flush;
2 changes: 1 addition & 1 deletion taichi/transforms/auto_diff.cpp
@@ -719,7 +719,7 @@ class MakeAdjoint : public ADTransform {
}

Stmt *adjoint(Stmt *stmt) {
if (!is_real(stmt->ret_type)) {
if (!is_real(stmt->ret_type) || stmt->is<ConstStmt>()) {
return constant(0);
}
if (adjoint_stmt.find(stmt) == adjoint_stmt.end()) {
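The one-line change above makes `adjoint()` short-circuit to a zero constant for any `ConstStmt`, even when its type is real: a literal carries no gradient, so there is no need to allocate or propagate an adjoint for it. A minimal sketch of the behavior this encodes (field names and `ti.cpu` are illustrative, not from the PR):

```python
import taichi as ti

ti.init(arch=ti.cpu)

x = ti.field(ti.f32, shape=(), needs_grad=True)
y = ti.field(ti.f32, shape=(), needs_grad=True)

@ti.kernel
def shift():
    y[None] = x[None] + 3.0  # the literal 3.0 is a ConstStmt: adjoint(3.0) == 0

x[None] = 1.0
shift()
y.grad[None] = 1.0
shift.grad()         # reverse mode: only x receives a gradient
print(x.grad[None])  # 1.0, since dy/dx = 1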
3 changes: 2 additions & 1 deletion taichi/transforms/cfg_optimization.cpp
@@ -10,6 +10,7 @@ namespace irpass {
bool cfg_optimization(
IRNode *root,
bool after_lower_access,
bool autodiff_enabled,
const std::optional<ControlFlowGraph::LiveVarAnalysisConfig>
&lva_config_opt) {
TI_AUTO_PROF;
@@ -18,7 +19,7 @@
while (true) {
bool modified = false;
cfg->simplify_graph();
if (cfg->store_to_load_forwarding(after_lower_access))
if (cfg->store_to_load_forwarding(after_lower_access, autodiff_enabled))
modified = true;
if (cfg->dead_store_elimination(after_lower_access, lva_config_opt))
modified = true;
30 changes: 21 additions & 9 deletions taichi/transforms/compile_to_offloads.cpp
@@ -77,7 +77,10 @@ void compile_to_offloads(IRNode *ir,
irpass::analysis::verify(ir);
}

irpass::full_simplify(ir, config, {false, kernel->program});
irpass::full_simplify(
ir, config,
{false, /*autodiff_enabled*/ autodiff_mode != AutodiffMode::kNone,
kernel->program});
print("Simplified I");
irpass::analysis::verify(ir);

@@ -89,9 +92,12 @@
// Remove local atomics here so that we don't have to handle their gradients
irpass::demote_atomics(ir, config);

irpass::full_simplify(ir, config, {false, kernel->program});
irpass::full_simplify(ir, config,
{false, /*autodiff_enabled*/ true, kernel->program});
irpass::auto_diff(ir, config, autodiff_mode, ad_use_stack);
irpass::full_simplify(ir, config, {false, kernel->program});
// TODO: Be careful with full_simplify when doing higher-order autodiff
irpass::full_simplify(ir, config,
{false, /*autodiff_enabled*/ false, kernel->program});
print("Gradient");
irpass::analysis::verify(ir);
}
@@ -106,7 +112,8 @@
print("Access flagged I");
irpass::analysis::verify(ir);

irpass::full_simplify(ir, config, {false, kernel->program});
irpass::full_simplify(ir, config,
{false, /*autodiff_enabled*/ false, kernel->program});
print("Simplified II");
irpass::analysis::verify(ir);

@@ -117,15 +124,16 @@
// TODO: This pass may be redundant as cfg_optimization() is already called
// in full_simplify().
if (config.opt_level > 0 && config.cfg_optimization) {
irpass::cfg_optimization(ir, false);
irpass::cfg_optimization(ir, false, /*autodiff_enabled*/ false);
print("Optimized by CFG");
irpass::analysis::verify(ir);
}

irpass::flag_access(ir);
print("Access flagged II");

irpass::full_simplify(ir, config, {false, kernel->program});
irpass::full_simplify(ir, config,
{false, /*autodiff_enabled*/ false, kernel->program});
print("Simplified III");
irpass::analysis::verify(ir);
}
@@ -187,7 +195,8 @@ void offload_to_executable(IRNode *ir,
if (config.make_mesh_block_local && config.arch == Arch::cuda) {
irpass::make_mesh_block_local(ir, config, {kernel->get_name()});
print("Make mesh block local");
irpass::full_simplify(ir, config, {false, kernel->program});
irpass::full_simplify(
ir, config, {false, /*autodiff_enabled*/ false, kernel->program});
print("Simplified X");
}
}
@@ -235,7 +244,9 @@
irpass::demote_operations(ir, config);
print("Operations demoted");

irpass::full_simplify(ir, config, {lower_global_access, kernel->program});
irpass::full_simplify(
ir, config,
{lower_global_access, /*autodiff_enabled*/ false, kernel->program});
print("Simplified IV");

if (determine_ad_stack_size) {
@@ -311,7 +322,8 @@ void compile_function(IRNode *ir,
irpass::type_check(ir, config);
print("Typechecked");

irpass::full_simplify(ir, config, {false, func->program});
irpass::full_simplify(
ir, config, {false, autodiff_mode != AutodiffMode::kNone, func->program});
print("Simplified");
irpass::analysis::verify(ir);
}
3 changes: 2 additions & 1 deletion taichi/transforms/simplify.cpp
@@ -665,7 +665,8 @@ void full_simplify(IRNode *root,
// not modified.
if (config.opt_level > 0 && (first_iteration || modified) &&
config.cfg_optimization &&
cfg_optimization(root, args.after_lower_access))
cfg_optimization(root, args.after_lower_access,
args.autodiff_enabled))
modified = true;
first_iteration = false;
if (!modified)
3 changes: 3 additions & 0 deletions taichi/transforms/simplify.h
@@ -11,6 +11,9 @@ class FullSimplifyPass : public Pass {

struct Args {
bool after_lower_access;
// Switch off some optimizations in store forwarding if there is an autodiff
// pass after full_simplify
bool autodiff_enabled;
Program *program;
};
};
3 changes: 2 additions & 1 deletion tests/cpp/transforms/inlining_test.cpp
@@ -51,7 +51,8 @@ TEST_F(InliningTest, ArgLoadOfArgLoad) {
irpass::type_check(kernel_block, CompileConfig());

irpass::inlining(kernel_block, CompileConfig(), {});
irpass::full_simplify(kernel_block, CompileConfig(), {false, prog_.get()});
irpass::full_simplify(kernel_block, CompileConfig(),
{false, false, prog_.get()});

EXPECT_EQ(kernel_block->size(), 4);
EXPECT_TRUE(irpass::analysis::same_statements(func_block, kernel_block));
8 changes: 6 additions & 2 deletions tests/python/test_ad_for.py
@@ -563,7 +563,9 @@ def test_inner_loops_local_variable():
assert x.grad[None] == 36.0


@test_utils.test(require=ti.extension.adstack, ad_stack_size=0)
@test_utils.test(require=ti.extension.adstack,
ad_stack_size=0,
exclude=[ti.cc])
def test_more_inner_loops_local_variable_adaptive_stack_size_tape():
x = ti.field(dtype=float, shape=(), needs_grad=True)
arr = ti.field(dtype=float, shape=(2), needs_grad=True)
@@ -590,7 +592,9 @@
assert x.grad[None] == 36.0


@test_utils.test(require=ti.extension.adstack, ad_stack_size=32)
@test_utils.test(require=ti.extension.adstack,
ad_stack_size=32,
exclude=[ti.cc])
def test_more_inner_loops_local_variable_fixed_stack_size_tape():
x = ti.field(dtype=float, shape=(), needs_grad=True)
arr = ti.field(dtype=float, shape=(2), needs_grad=True)
41 changes: 40 additions & 1 deletion tests/python/test_ad_for_fwd.py
@@ -144,7 +144,7 @@ def fib():


@test_utils.test()
def test_ad_fibonacci_index():
def test_ad_fibonacci_index_fwd():
N = 5
M = 10
a = ti.field(ti.f32, shape=M)
@@ -173,3 +173,42 @@ def fib():
is_fib = int(i in [1, 2, 3, 5, 8])
assert f.dual[i] == is_fib * N
assert b[i] == is_fib * N


@test_utils.test(exclude=[ti.cc])
def test_double_for_loops():
N = 5
a = ti.field(ti.f32, shape=N)
b = ti.field(ti.f32, shape=N)
c = ti.field(ti.i32, shape=N)
f = ti.field(ti.f32, shape=N)
ti.root.lazy_dual()

@ti.kernel
def double_for():
for i in range(N):
weight = 1.0
for j in range(c[i]):
weight *= a[i]
s = 0.0
for j in range(c[i] * 2):
s += weight + b[i]
f[i] = s

a.fill(2)
b.fill(1)

for i in range(N):
c[i] = i

with ti.ad.FwdMode(loss=f, parameters=a, seed=[1.0 for _ in range(N)]):
double_for()

for i in range(N):
assert f.dual[i] == 2 * i * i * 2**(i - 1)

with ti.ad.FwdMode(loss=f, parameters=b, seed=[1.0 for _ in range(N)]):
double_for()

for i in range(N):
assert f.dual[i] == 2 * i
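
For reference (not part of the diff), the asserted duals follow directly from the kernel: with `a.fill(2)`, `b.fill(1)`, and `c[i] = i`, the two inner loops compute

$$f_i = 2c_i\left(a_i^{c_i} + b_i\right),\qquad
\frac{\partial f_i}{\partial a_i} = 2c_i^{2}\,a_i^{c_i-1} = 2 \cdot i \cdot i \cdot 2^{i-1},\qquad
\frac{\partial f_i}{\partial b_i} = 2c_i = 2i,$$

matching the two assertions above.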