From ae7605333d036c949b73c28c2b9a51b95aef8778 Mon Sep 17 00:00:00 2001 From: mingrui Date: Wed, 6 Jul 2022 17:23:02 +0800 Subject: [PATCH 01/11] [autodiff] Switch off parts of store forwarding optimization for autodiff --- taichi/ir/control_flow_graph.cpp | 53 ++++++++++------------- taichi/ir/control_flow_graph.h | 6 ++- taichi/ir/transforms.h | 1 + taichi/program/ir_bank.cpp | 8 ++-- taichi/transforms/cfg_optimization.cpp | 3 +- taichi/transforms/compile_to_offloads.cpp | 30 +++++++++---- taichi/transforms/simplify.cpp | 3 +- taichi/transforms/simplify.h | 3 ++ tests/cpp/transforms/inlining_test.cpp | 3 +- tests/python/test_ad_for_fwd.py | 41 +++++++++++++++++- 10 files changed, 102 insertions(+), 49 deletions(-) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index b00eb9dc8366b..65c0af9e26c67 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -255,7 +255,8 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access) { } } -bool CFGNode::store_to_load_forwarding(bool after_lower_access) { +bool CFGNode::store_to_load_forwarding(bool after_lower_access, + bool with_autodiff_after) { bool modified = false; for (int i = begin_location; i < end_location; i++) { // Store-to-load forwarding @@ -274,22 +275,8 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) { result = get_store_forwarding_data(alloca, i); } } else if (auto global_load = stmt->cast()) { - if (!after_lower_access) { - bool store_forwarding = true; - if (auto global_ptr = global_load->src->cast()) { - TI_ASSERT(global_ptr->width() == 1); - auto &snodes = global_ptr->snodes; - if (snodes[0]->has_adjoint()) { - // Has adjoint SNode. Skip the store forwarding - // to keep the global load chain, - // so that the grad of intermidiate variable can be computed - // by GlobalLoadStmt - store_forwarding = false; - } - } - if (store_forwarding) { - result = get_store_forwarding_data(global_load->src, i); - } + if (!after_lower_access && !with_autodiff_after) { + result = get_store_forwarding_data(global_load->src, i); } } if (result) { @@ -313,19 +300,21 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) { result = get_store_forwarding_data(local_store->dest, i); if (result) { if (result->is()) { - // special case of alloca (initialized to 0) - if (auto stored_data = local_store->val->cast()) { - bool all_zero = true; - for (auto &val : stored_data->val.data) { - if (!val.equal_value(0)) { - all_zero = false; - break; + if (!with_autodiff_after) { + // special case of alloca (initialized to 0) + if (auto stored_data = local_store->val->cast()) { + bool all_zero = true; + for (auto &val : stored_data->val.data) { + if (!val.equal_value(0)) { + all_zero = false; + break; + } + } + if (all_zero) { + erase(i); // This causes end_location-- + i--; // to cancel i++ in the for loop + modified = true; } - } - if (all_zero) { - erase(i); // This causes end_location-- - i--; // to cancel i++ in the for loop - modified = true; } } } else { @@ -857,13 +846,15 @@ bool ControlFlowGraph::unreachable_code_elimination() { return modified; } -bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access) { +bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access, + bool with_autodiff_after) { TI_AUTO_PROF; reaching_definition_analysis(after_lower_access); const int num_nodes = size(); bool modified = false; for (int i = 0; i < num_nodes; i++) { - if (nodes[i]->store_to_load_forwarding(after_lower_access)) + if (nodes[i]->store_to_load_forwarding(after_lower_access, + with_autodiff_after)) modified = true; } return modified; diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index ecb0f38d22311..2d234364a0ec2 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -76,7 +76,8 @@ class CFGNode { // Analyses and optimizations inside a CFGNode. void reaching_definition_analysis(bool after_lower_access); - bool store_to_load_forwarding(bool after_lower_access); + bool store_to_load_forwarding(bool after_lower_access, + bool with_autodiff_after); void gather_loaded_snodes(std::unordered_set &snodes) const; void live_variable_analysis(bool after_lower_access); bool dead_store_elimination(bool after_lower_access); @@ -145,7 +146,8 @@ class ControlFlowGraph { /** * Perform store-to-load forwarding and identical store elimination. */ - bool store_to_load_forwarding(bool after_lower_access); + bool store_to_load_forwarding(bool after_lower_access, + bool with_autodiff_after); /** * Perform dead store elimination and identical load elimination. diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index bf6be74a3f7a1..20771eb9d9203 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -34,6 +34,7 @@ bool simplify(IRNode *root, const CompileConfig &config); bool cfg_optimization( IRNode *root, bool after_lower_access, + bool with_autodiff_after, const std::optional &lva_config_opt = std::nullopt); bool alg_simp(IRNode *root, const CompileConfig &config); diff --git a/taichi/program/ir_bank.cpp b/taichi/program/ir_bank.cpp index 522debb2e33f7..e5b6df09bb826 100644 --- a/taichi/program/ir_bank.cpp +++ b/taichi/program/ir_bank.cpp @@ -114,7 +114,8 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) { } irpass::full_simplify(task_a, kernel->program->config, - {/*after_lower_access=*/false, kernel->program}); + {/*after_lower_access=*/false, + /*with_autodiff_after*/ false, kernel->program}); // For now, re_id is necessary for the hash to be correct. irpass::re_id(task_a); @@ -208,8 +209,9 @@ std::pair IRBank::optimize_dse( } ControlFlowGraph::LiveVarAnalysisConfig lva_config; lva_config.eliminable_snodes = {snodes.begin(), snodes.end()}; - const bool modified = irpass::cfg_optimization( - new_ir.get(), /*after_lower_access=*/false, lva_config); + const bool modified = + irpass::cfg_optimization(new_ir.get(), /*after_lower_access=*/false, + /*with_autodiff_after*/ false, lva_config); if (verbose) { TI_INFO(" DSE: after CFG, modified={}", modified); std::cout << std::flush; diff --git a/taichi/transforms/cfg_optimization.cpp b/taichi/transforms/cfg_optimization.cpp index 48a49a6942eae..7fb702312e84a 100644 --- a/taichi/transforms/cfg_optimization.cpp +++ b/taichi/transforms/cfg_optimization.cpp @@ -10,6 +10,7 @@ namespace irpass { bool cfg_optimization( IRNode *root, bool after_lower_access, + bool with_autodiff_after, const std::optional &lva_config_opt) { TI_AUTO_PROF; @@ -18,7 +19,7 @@ bool cfg_optimization( while (true) { bool modified = false; cfg->simplify_graph(); - if (cfg->store_to_load_forwarding(after_lower_access)) + if (cfg->store_to_load_forwarding(after_lower_access, with_autodiff_after)) modified = true; if (cfg->dead_store_elimination(after_lower_access, lva_config_opt)) modified = true; diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index cfe63000e7194..86321fb0b7a45 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -77,7 +77,10 @@ void compile_to_offloads(IRNode *ir, irpass::analysis::verify(ir); } - irpass::full_simplify(ir, config, {false, kernel->program}); + irpass::full_simplify( + ir, config, + {false, /*with_autodiff_after*/ autodiff_mode != AutodiffMode::kNone, + kernel->program}); print("Simplified I"); irpass::analysis::verify(ir); @@ -89,9 +92,12 @@ void compile_to_offloads(IRNode *ir, // Remove local atomics here so that we don't have to handle their gradients irpass::demote_atomics(ir, config); - irpass::full_simplify(ir, config, {false, kernel->program}); + irpass::full_simplify( + ir, config, {false, /*with_autodiff_after*/ true, kernel->program}); irpass::auto_diff(ir, config, autodiff_mode, ad_use_stack); - irpass::full_simplify(ir, config, {false, kernel->program}); + // TODO: Be carefull with the full_simplify when do high-order autodiff + irpass::full_simplify( + ir, config, {false, /*with_autodiff_after*/ false, kernel->program}); print("Gradient"); irpass::analysis::verify(ir); } @@ -106,7 +112,8 @@ void compile_to_offloads(IRNode *ir, print("Access flagged I"); irpass::analysis::verify(ir); - irpass::full_simplify(ir, config, {false, kernel->program}); + irpass::full_simplify( + ir, config, {false, /*with_autodiff_after*/ false, kernel->program}); print("Simplified II"); irpass::analysis::verify(ir); @@ -117,7 +124,7 @@ void compile_to_offloads(IRNode *ir, // TODO: This pass may be redundant as cfg_optimization() is already called // in full_simplify(). if (config.opt_level > 0 && config.cfg_optimization) { - irpass::cfg_optimization(ir, false); + irpass::cfg_optimization(ir, false, /*with_autodiff_after*/ false); print("Optimized by CFG"); irpass::analysis::verify(ir); } @@ -125,7 +132,8 @@ void compile_to_offloads(IRNode *ir, irpass::flag_access(ir); print("Access flagged II"); - irpass::full_simplify(ir, config, {false, kernel->program}); + irpass::full_simplify( + ir, config, {false, /*with_autodiff_after*/ false, kernel->program}); print("Simplified III"); irpass::analysis::verify(ir); } @@ -187,7 +195,8 @@ void offload_to_executable(IRNode *ir, if (config.make_mesh_block_local && config.arch == Arch::cuda) { irpass::make_mesh_block_local(ir, config, {kernel->get_name()}); print("Make mesh block local"); - irpass::full_simplify(ir, config, {false, kernel->program}); + irpass::full_simplify( + ir, config, {false, /*with_autodiff_after*/ false, kernel->program}); print("Simplified X"); } } @@ -235,7 +244,9 @@ void offload_to_executable(IRNode *ir, irpass::demote_operations(ir, config); print("Operations demoted"); - irpass::full_simplify(ir, config, {lower_global_access, kernel->program}); + irpass::full_simplify( + ir, config, + {lower_global_access, /*with_autodiff_after*/ false, kernel->program}); print("Simplified IV"); if (determine_ad_stack_size) { @@ -311,7 +322,8 @@ void compile_function(IRNode *ir, irpass::type_check(ir, config); print("Typechecked"); - irpass::full_simplify(ir, config, {false, func->program}); + irpass::full_simplify( + ir, config, {false, autodiff_mode != AutodiffMode::kNone, func->program}); print("Simplified"); irpass::analysis::verify(ir); } diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 2913f36cc810b..12d0f749b9b06 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -665,7 +665,8 @@ void full_simplify(IRNode *root, // not modified. if (config.opt_level > 0 && (first_iteration || modified) && config.cfg_optimization && - cfg_optimization(root, args.after_lower_access)) + cfg_optimization(root, args.after_lower_access, + args.with_autodiff_after)) modified = true; first_iteration = false; if (!modified) diff --git a/taichi/transforms/simplify.h b/taichi/transforms/simplify.h index fd4b58910e983..a6ab695ed4e52 100644 --- a/taichi/transforms/simplify.h +++ b/taichi/transforms/simplify.h @@ -11,6 +11,9 @@ class FullSimplifyPass : public Pass { struct Args { bool after_lower_access; + // Switch off some optimization in store forwarding if there is an autodiff + // pass after the full_simplify + bool with_autodiff_after; Program *program; }; }; diff --git a/tests/cpp/transforms/inlining_test.cpp b/tests/cpp/transforms/inlining_test.cpp index 6e77c4bf0f561..1ddfcfc9aefbd 100644 --- a/tests/cpp/transforms/inlining_test.cpp +++ b/tests/cpp/transforms/inlining_test.cpp @@ -51,7 +51,8 @@ TEST_F(InliningTest, ArgLoadOfArgLoad) { irpass::type_check(kernel_block, CompileConfig()); irpass::inlining(kernel_block, CompileConfig(), {}); - irpass::full_simplify(kernel_block, CompileConfig(), {false, prog_.get()}); + irpass::full_simplify(kernel_block, CompileConfig(), + {false, false, prog_.get()}); EXPECT_EQ(kernel_block->size(), 4); EXPECT_TRUE(irpass::analysis::same_statements(func_block, kernel_block)); diff --git a/tests/python/test_ad_for_fwd.py b/tests/python/test_ad_for_fwd.py index 3e05cd3dd97f1..0dc042050e824 100644 --- a/tests/python/test_ad_for_fwd.py +++ b/tests/python/test_ad_for_fwd.py @@ -144,7 +144,7 @@ def fib(): @test_utils.test() -def test_ad_fibonacci_index(): +def test_ad_fibonacci_index_fwd(): N = 5 M = 10 a = ti.field(ti.f32, shape=M) @@ -173,3 +173,42 @@ def fib(): is_fib = int(i in [1, 2, 3, 5, 8]) assert f.dual[i] == is_fib * N assert b[i] == is_fib * N + + +@test_utils.test(exclude=[ti.cc]) +def test_double_for_loops(): + N = 5 + a = ti.field(ti.f32, shape=N) + b = ti.field(ti.f32, shape=N) + c = ti.field(ti.i32, shape=N) + f = ti.field(ti.f32, shape=N) + ti.root.lazy_dual() + + @ti.kernel + def double_for(): + for i in range(N): + weight = 1.0 + for j in range(c[i]): + weight *= a[i] + s = 0.0 + for j in range(c[i] * 2): + s += weight + b[i] + f[i] = s + + a.fill(2) + b.fill(1) + + for i in range(N): + c[i] = i + + with ti.ad.FwdMode(loss=f, parameters=a, seed=[1.0 for _ in range(N)]): + double_for() + + for i in range(N): + assert f.dual[i] == 2 * i * i * 2**(i - 1) + + with ti.ad.FwdMode(loss=f, parameters=b, seed=[1.0 for _ in range(N)]): + double_for() + + for i in range(N): + assert f.dual[i] == 2 * i From e6cc0016b667dfb9418b9b4dcb91b35366249ee8 Mon Sep 17 00:00:00 2001 From: mingrui Date: Wed, 6 Jul 2022 23:35:15 +0800 Subject: [PATCH 02/11] update opt args --- taichi/transforms/compile_to_offloads.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 86321fb0b7a45..401956cbafaaf 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -79,7 +79,7 @@ void compile_to_offloads(IRNode *ir, irpass::full_simplify( ir, config, - {false, /*with_autodiff_after*/ autodiff_mode != AutodiffMode::kNone, + {false, /*with_autodiff_after*/ autodiff_mode != AutodiffMode::kForward, kernel->program}); print("Simplified I"); irpass::analysis::verify(ir); @@ -93,7 +93,9 @@ void compile_to_offloads(IRNode *ir, irpass::demote_atomics(ir, config); irpass::full_simplify( - ir, config, {false, /*with_autodiff_after*/ true, kernel->program}); + ir, config, + {false, /*with_autodiff_after*/ autodiff_mode != kForward, + kernel->program}); irpass::auto_diff(ir, config, autodiff_mode, ad_use_stack); // TODO: Be carefull with the full_simplify when do high-order autodiff irpass::full_simplify( From 4c05161989e4ae71b874a79b565b000f2a68c7b7 Mon Sep 17 00:00:00 2001 From: mingrui Date: Wed, 6 Jul 2022 23:36:51 +0800 Subject: [PATCH 03/11] update opt args --- taichi/transforms/compile_to_offloads.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 401956cbafaaf..cafd478f37b07 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -94,7 +94,7 @@ void compile_to_offloads(IRNode *ir, irpass::full_simplify( ir, config, - {false, /*with_autodiff_after*/ autodiff_mode != kForward, + {false, /*with_autodiff_after*/ autodiff_mode != AutodiffMode::kForward, kernel->program}); irpass::auto_diff(ir, config, autodiff_mode, ad_use_stack); // TODO: Be carefull with the full_simplify when do high-order autodiff From 53d1cf9a35687d8b6abf902b6ceb19ce9e20c174 Mon Sep 17 00:00:00 2001 From: mingrui Date: Wed, 6 Jul 2022 23:40:42 +0800 Subject: [PATCH 04/11] update opt args --- taichi/transforms/compile_to_offloads.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index cafd478f37b07..0585f6f83ef3b 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -79,7 +79,7 @@ void compile_to_offloads(IRNode *ir, irpass::full_simplify( ir, config, - {false, /*with_autodiff_after*/ autodiff_mode != AutodiffMode::kForward, + {false, /*with_autodiff_after*/ autodiff_mode == AutodiffMode::kForward, kernel->program}); print("Simplified I"); irpass::analysis::verify(ir); @@ -94,7 +94,7 @@ void compile_to_offloads(IRNode *ir, irpass::full_simplify( ir, config, - {false, /*with_autodiff_after*/ autodiff_mode != AutodiffMode::kForward, + {false, /*with_autodiff_after*/ autodiff_mode == AutodiffMode::kForward, kernel->program}); irpass::auto_diff(ir, config, autodiff_mode, ad_use_stack); // TODO: Be carefull with the full_simplify when do high-order autodiff From d6772409b08705d8db8679815795595e6590c5e9 Mon Sep 17 00:00:00 2001 From: mingrui Date: Thu, 7 Jul 2022 00:03:07 +0800 Subject: [PATCH 05/11] update opt args --- taichi/transforms/compile_to_offloads.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 0585f6f83ef3b..1c4e7fd44fddf 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -79,7 +79,7 @@ void compile_to_offloads(IRNode *ir, irpass::full_simplify( ir, config, - {false, /*with_autodiff_after*/ autodiff_mode == AutodiffMode::kForward, + {false, /*with_autodiff_after*/ autodiff_mode != AutodiffMode::kNone, kernel->program}); print("Simplified I"); irpass::analysis::verify(ir); @@ -94,7 +94,7 @@ void compile_to_offloads(IRNode *ir, irpass::full_simplify( ir, config, - {false, /*with_autodiff_after*/ autodiff_mode == AutodiffMode::kForward, + {false, /*with_autodiff_after*/ autodiff_mode != AutodiffMode::kNone, kernel->program}); irpass::auto_diff(ir, config, autodiff_mode, ad_use_stack); // TODO: Be carefull with the full_simplify when do high-order autodiff From c69d0719a3d74a730ea2d18d18f982840fa1a788 Mon Sep 17 00:00:00 2001 From: mingrui Date: Thu, 7 Jul 2022 00:13:30 +0800 Subject: [PATCH 06/11] exclude cc backend --- tests/python/test_ad_for.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/python/test_ad_for.py b/tests/python/test_ad_for.py index 1f76ce5c49cf5..cceea48764849 100644 --- a/tests/python/test_ad_for.py +++ b/tests/python/test_ad_for.py @@ -563,7 +563,9 @@ def test_inner_loops_local_variable(): assert x.grad[None] == 36.0 -@test_utils.test(require=ti.extension.adstack, ad_stack_size=0) +@test_utils.test(require=ti.extension.adstack, + ad_stack_size=0, + exclude=[ti.cc]) def test_more_inner_loops_local_variable_adaptive_stack_size_tape(): x = ti.field(dtype=float, shape=(), needs_grad=True) arr = ti.field(dtype=float, shape=(2), needs_grad=True) @@ -590,7 +592,9 @@ def test_more_inner_loops_local_variable(): assert x.grad[None] == 36.0 -@test_utils.test(require=ti.extension.adstack, ad_stack_size=32) +@test_utils.test(require=ti.extension.adstack, + ad_stack_size=32, + exclude=[ti.cc]) def test_more_inner_loops_local_variable_fixed_stack_size_tape(): x = ti.field(dtype=float, shape=(), needs_grad=True) arr = ti.field(dtype=float, shape=(2), needs_grad=True) From 011de883559533ca4f59b4a885fee28d47f68cc6 Mon Sep 17 00:00:00 2001 From: mingrui Date: Thu, 7 Jul 2022 15:22:38 +0800 Subject: [PATCH 07/11] no need to push zero to stack --- taichi/transforms/auto_diff.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index 43002e9f584ff..445e0e9d59826 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -430,10 +430,10 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor { // Note that unlike AllocaStmt, AdStackAllocaStmt does NOT have an 0 as // initial value. Therefore here we push an initial 0 value. - auto zero = stack_alloca_ptr->insert_after_me( - Stmt::make(TypedConstant(dtype, 0))); - zero->insert_after_me( - Stmt::make(stack_alloca_ptr, zero)); + // auto zero = stack_alloca_ptr->insert_after_me( + // Stmt::make(TypedConstant(dtype, 0))); + // zero->insert_after_me( + // Stmt::make(stack_alloca_ptr, zero)); } } From 37f002b16fc6b6e4a188dfe26efb6ed2d338992e Mon Sep 17 00:00:00 2001 From: mingrui Date: Thu, 7 Jul 2022 15:24:17 +0800 Subject: [PATCH 08/11] no need to push zero to stack --- taichi/transforms/auto_diff.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index 445e0e9d59826..da11cbe9e9075 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -424,7 +424,7 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor { if (is_stack_needed) { auto dtype = alloc->ret_type; auto stack_alloca = Stmt::make(dtype, ad_stack_size); - auto stack_alloca_ptr = stack_alloca.get(); + // auto stack_alloca_ptr = stack_alloca.get(); alloc->replace_with(std::move(stack_alloca)); From 0622d8b88d0027c3d2787e2905e3b66029e226f1 Mon Sep 17 00:00:00 2001 From: mingrui Date: Tue, 19 Jul 2022 16:19:29 +0800 Subject: [PATCH 09/11] exclude ConstantStmt when generating adjoint --- taichi/transforms/auto_diff.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index da11cbe9e9075..7e6c04c7bdb02 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -719,7 +719,7 @@ class MakeAdjoint : public ADTransform { } Stmt *adjoint(Stmt *stmt) { - if (!is_real(stmt->ret_type)) { + if (!is_real(stmt->ret_type) || stmt->is()) { return constant(0); } if (adjoint_stmt.find(stmt) == adjoint_stmt.end()) { From 17392eec299af95327c80b316b3fd34fd8e90d5b Mon Sep 17 00:00:00 2001 From: mingrui Date: Tue, 19 Jul 2022 16:27:53 +0800 Subject: [PATCH 10/11] recover the initial zero for ad stack --- taichi/transforms/auto_diff.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index 7e6c04c7bdb02..2f7087e0afa3f 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -424,16 +424,16 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor { if (is_stack_needed) { auto dtype = alloc->ret_type; auto stack_alloca = Stmt::make(dtype, ad_stack_size); - // auto stack_alloca_ptr = stack_alloca.get(); + auto stack_alloca_ptr = stack_alloca.get(); alloc->replace_with(std::move(stack_alloca)); // Note that unlike AllocaStmt, AdStackAllocaStmt does NOT have an 0 as // initial value. Therefore here we push an initial 0 value. - // auto zero = stack_alloca_ptr->insert_after_me( - // Stmt::make(TypedConstant(dtype, 0))); - // zero->insert_after_me( - // Stmt::make(stack_alloca_ptr, zero)); + auto zero = stack_alloca_ptr->insert_after_me( + Stmt::make(TypedConstant(dtype, 0))); + zero->insert_after_me( + Stmt::make(stack_alloca_ptr, zero)); } } From 1e72b0f2bd5de19cc57c2b884eab9de1c4c84010 Mon Sep 17 00:00:00 2001 From: mingrui Date: Thu, 21 Jul 2022 17:06:27 +0800 Subject: [PATCH 11/11] update --- taichi/ir/control_flow_graph.cpp | 44 +++++++++++------------ taichi/ir/control_flow_graph.h | 6 ++-- taichi/ir/transforms.h | 2 +- taichi/program/ir_bank.cpp | 4 +-- taichi/transforms/cfg_optimization.cpp | 4 +-- taichi/transforms/compile_to_offloads.cpp | 26 +++++++------- taichi/transforms/simplify.cpp | 2 +- taichi/transforms/simplify.h | 2 +- 8 files changed, 41 insertions(+), 49 deletions(-) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 65c0af9e26c67..3c4ddfedf2cac 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -256,7 +256,7 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access) { } bool CFGNode::store_to_load_forwarding(bool after_lower_access, - bool with_autodiff_after) { + bool autodiff_enabled) { bool modified = false; for (int i = begin_location; i < end_location; i++) { // Store-to-load forwarding @@ -275,7 +275,7 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access, result = get_store_forwarding_data(alloca, i); } } else if (auto global_load = stmt->cast()) { - if (!after_lower_access && !with_autodiff_after) { + if (!after_lower_access && !autodiff_enabled) { result = get_store_forwarding_data(global_load->src, i); } } @@ -298,33 +298,29 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access, // Identical store elimination if (auto local_store = stmt->cast()) { result = get_store_forwarding_data(local_store->dest, i); - if (result) { - if (result->is()) { - if (!with_autodiff_after) { - // special case of alloca (initialized to 0) - if (auto stored_data = local_store->val->cast()) { - bool all_zero = true; - for (auto &val : stored_data->val.data) { - if (!val.equal_value(0)) { - all_zero = false; - break; - } - } - if (all_zero) { - erase(i); // This causes end_location-- - i--; // to cancel i++ in the for loop - modified = true; - } + if (result && result->is() && !autodiff_enabled) { + // special case of alloca (initialized to 0) + if (auto stored_data = local_store->val->cast()) { + bool all_zero = true; + for (auto &val : stored_data->val.data) { + if (!val.equal_value(0)) { + all_zero = false; + break; } } - } else { - // not alloca - if (irpass::analysis::same_value(result, local_store->val)) { + if (all_zero) { erase(i); // This causes end_location-- i--; // to cancel i++ in the for loop modified = true; } } + } else { + // not alloca + if (irpass::analysis::same_value(result, local_store->val)) { + erase(i); // This causes end_location-- + i--; // to cancel i++ in the for loop + modified = true; + } } } else if (auto global_store = stmt->cast()) { if (!after_lower_access) { @@ -847,14 +843,14 @@ bool ControlFlowGraph::unreachable_code_elimination() { } bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access, - bool with_autodiff_after) { + bool autodiff_enabled) { TI_AUTO_PROF; reaching_definition_analysis(after_lower_access); const int num_nodes = size(); bool modified = false; for (int i = 0; i < num_nodes; i++) { if (nodes[i]->store_to_load_forwarding(after_lower_access, - with_autodiff_after)) + autodiff_enabled)) modified = true; } return modified; diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index 2d234364a0ec2..1776a57ea4eea 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -76,8 +76,7 @@ class CFGNode { // Analyses and optimizations inside a CFGNode. void reaching_definition_analysis(bool after_lower_access); - bool store_to_load_forwarding(bool after_lower_access, - bool with_autodiff_after); + bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled); void gather_loaded_snodes(std::unordered_set &snodes) const; void live_variable_analysis(bool after_lower_access); bool dead_store_elimination(bool after_lower_access); @@ -146,8 +145,7 @@ class ControlFlowGraph { /** * Perform store-to-load forwarding and identical store elimination. */ - bool store_to_load_forwarding(bool after_lower_access, - bool with_autodiff_after); + bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled); /** * Perform dead store elimination and identical load elimination. diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index 20771eb9d9203..6ef0794b1d17a 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -34,7 +34,7 @@ bool simplify(IRNode *root, const CompileConfig &config); bool cfg_optimization( IRNode *root, bool after_lower_access, - bool with_autodiff_after, + bool autodiff_enabled, const std::optional &lva_config_opt = std::nullopt); bool alg_simp(IRNode *root, const CompileConfig &config); diff --git a/taichi/program/ir_bank.cpp b/taichi/program/ir_bank.cpp index e5b6df09bb826..a155beaca4cd7 100644 --- a/taichi/program/ir_bank.cpp +++ b/taichi/program/ir_bank.cpp @@ -115,7 +115,7 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) { irpass::full_simplify(task_a, kernel->program->config, {/*after_lower_access=*/false, - /*with_autodiff_after*/ false, kernel->program}); + /*autodiff_enabled*/ false, kernel->program}); // For now, re_id is necessary for the hash to be correct. irpass::re_id(task_a); @@ -211,7 +211,7 @@ std::pair IRBank::optimize_dse( lva_config.eliminable_snodes = {snodes.begin(), snodes.end()}; const bool modified = irpass::cfg_optimization(new_ir.get(), /*after_lower_access=*/false, - /*with_autodiff_after*/ false, lva_config); + /*autodiff_enabled*/ false, lva_config); if (verbose) { TI_INFO(" DSE: after CFG, modified={}", modified); std::cout << std::flush; diff --git a/taichi/transforms/cfg_optimization.cpp b/taichi/transforms/cfg_optimization.cpp index 7fb702312e84a..93efc651a411d 100644 --- a/taichi/transforms/cfg_optimization.cpp +++ b/taichi/transforms/cfg_optimization.cpp @@ -10,7 +10,7 @@ namespace irpass { bool cfg_optimization( IRNode *root, bool after_lower_access, - bool with_autodiff_after, + bool autodiff_enabled, const std::optional &lva_config_opt) { TI_AUTO_PROF; @@ -19,7 +19,7 @@ bool cfg_optimization( while (true) { bool modified = false; cfg->simplify_graph(); - if (cfg->store_to_load_forwarding(after_lower_access, with_autodiff_after)) + if (cfg->store_to_load_forwarding(after_lower_access, autodiff_enabled)) modified = true; if (cfg->dead_store_elimination(after_lower_access, lva_config_opt)) modified = true; diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 1c4e7fd44fddf..76562af80156f 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -79,7 +79,7 @@ void compile_to_offloads(IRNode *ir, irpass::full_simplify( ir, config, - {false, /*with_autodiff_after*/ autodiff_mode != AutodiffMode::kNone, + {false, /*autodiff_enabled*/ autodiff_mode != AutodiffMode::kNone, kernel->program}); print("Simplified I"); irpass::analysis::verify(ir); @@ -92,14 +92,12 @@ void compile_to_offloads(IRNode *ir, // Remove local atomics here so that we don't have to handle their gradients irpass::demote_atomics(ir, config); - irpass::full_simplify( - ir, config, - {false, /*with_autodiff_after*/ autodiff_mode != AutodiffMode::kNone, - kernel->program}); + irpass::full_simplify(ir, config, + {false, /*autodiff_enabled*/ true, kernel->program}); irpass::auto_diff(ir, config, autodiff_mode, ad_use_stack); // TODO: Be carefull with the full_simplify when do high-order autodiff - irpass::full_simplify( - ir, config, {false, /*with_autodiff_after*/ false, kernel->program}); + irpass::full_simplify(ir, config, + {false, /*autodiff_enabled*/ false, kernel->program}); print("Gradient"); irpass::analysis::verify(ir); } @@ -114,8 +112,8 @@ void compile_to_offloads(IRNode *ir, print("Access flagged I"); irpass::analysis::verify(ir); - irpass::full_simplify( - ir, config, {false, /*with_autodiff_after*/ false, kernel->program}); + irpass::full_simplify(ir, config, + {false, /*autodiff_enabled*/ false, kernel->program}); print("Simplified II"); irpass::analysis::verify(ir); @@ -126,7 +124,7 @@ void compile_to_offloads(IRNode *ir, // TODO: This pass may be redundant as cfg_optimization() is already called // in full_simplify(). if (config.opt_level > 0 && config.cfg_optimization) { - irpass::cfg_optimization(ir, false, /*with_autodiff_after*/ false); + irpass::cfg_optimization(ir, false, /*autodiff_enabled*/ false); print("Optimized by CFG"); irpass::analysis::verify(ir); } @@ -134,8 +132,8 @@ void compile_to_offloads(IRNode *ir, irpass::flag_access(ir); print("Access flagged II"); - irpass::full_simplify( - ir, config, {false, /*with_autodiff_after*/ false, kernel->program}); + irpass::full_simplify(ir, config, + {false, /*autodiff_enabled*/ false, kernel->program}); print("Simplified III"); irpass::analysis::verify(ir); } @@ -198,7 +196,7 @@ void offload_to_executable(IRNode *ir, irpass::make_mesh_block_local(ir, config, {kernel->get_name()}); print("Make mesh block local"); irpass::full_simplify( - ir, config, {false, /*with_autodiff_after*/ false, kernel->program}); + ir, config, {false, /*autodiff_enabled*/ false, kernel->program}); print("Simplified X"); } } @@ -248,7 +246,7 @@ void offload_to_executable(IRNode *ir, irpass::full_simplify( ir, config, - {lower_global_access, /*with_autodiff_after*/ false, kernel->program}); + {lower_global_access, /*autodiff_enabled*/ false, kernel->program}); print("Simplified IV"); if (determine_ad_stack_size) { diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 12d0f749b9b06..af09d42a21df1 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -666,7 +666,7 @@ void full_simplify(IRNode *root, if (config.opt_level > 0 && (first_iteration || modified) && config.cfg_optimization && cfg_optimization(root, args.after_lower_access, - args.with_autodiff_after)) + args.autodiff_enabled)) modified = true; first_iteration = false; if (!modified) diff --git a/taichi/transforms/simplify.h b/taichi/transforms/simplify.h index a6ab695ed4e52..ada1cec40bd46 100644 --- a/taichi/transforms/simplify.h +++ b/taichi/transforms/simplify.h @@ -13,7 +13,7 @@ class FullSimplifyPass : public Pass { bool after_lower_access; // Switch off some optimization in store forwarding if there is an autodiff // pass after the full_simplify - bool with_autodiff_after; + bool autodiff_enabled; Program *program; }; };