From 11f0cb9fc1072072a0a50edcbc583b847a36bc0b Mon Sep 17 00:00:00 2001 From: Mingrui Zhang <33411325+erizmr@users.noreply.github.com> Date: Wed, 13 Jul 2022 09:52:35 +0800 Subject: [PATCH] [autodiff] Add ternary operators for forward mode (#5405) * [autodiff] Add ternary operators for forward mode * add ternary test case for forward mode --- taichi/transforms/auto_diff.cpp | 9 ++++++++ tests/python/test_ad_basics.py | 38 +++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index 3227497478ca1..43002e9f584ff 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -1200,6 +1200,15 @@ class MakeDual : public ADTransform { } } + void visit(TernaryOpStmt *stmt) override { + TI_ASSERT(stmt->op_type == TernaryOpType::select); + auto zero = insert(TypedConstant(stmt->ret_type)); + accumulate(stmt, insert(TernaryOpType::select, stmt->op1, + load(dual(stmt->op2)), zero)); + accumulate(stmt, insert(TernaryOpType::select, stmt->op1, + zero, load(dual(stmt->op3)))); + } + void visit(IfStmt *if_stmt) override { if (if_stmt->true_statements) { std::vector true_statements; diff --git a/tests/python/test_ad_basics.py b/tests/python/test_ad_basics.py index b6c1687b89632..5cb83ac56767e 100644 --- a/tests/python/test_ad_basics.py +++ b/tests/python/test_ad_basics.py @@ -288,6 +288,44 @@ def func(): assert y.grad[i] == (not i % 2) * 1.0 +@test_utils.test() +def test_select_fwd(): + N = 5 + loss = ti.field(ti.f32, shape=N) + x = ti.field(ti.f32, shape=N) + y = ti.field(ti.f32, shape=N) + ti.root.lazy_dual() + + for i in range(N): + x[i] = i + y[i] = -i + + @ti.kernel + def func(): + for i in range(N): + loss[i] = ti.select(i % 2, x[i], y[i]) + + with ti.ad.FwdMode(loss=loss, parameters=x, seed=[1.0 for _ in range(N)]): + func() + + for i in range(N): + if i % 2: + assert loss[i] == i + else: + assert loss[i] == -i + assert loss.dual[i] == i % 2 * 1.0 + + with ti.ad.FwdMode(loss=loss, parameters=y, seed=[1.0 for _ in range(N)]): + func() + + for i in range(N): + if i % 2: + assert loss[i] == i + else: + assert loss[i] == -i + assert loss.dual[i] == (not i % 2) * 1.0 + + @test_utils.test() def test_obey_kernel_simplicity(): x = ti.field(ti.f32)