Skip to content

Commit

Permalink
[autodiff] Add ternary operators for forward mode (#5405)
Browse files Browse the repository at this point in the history
* [autodiff] Add ternary operators for forward mode

* add ternary test case for forward mode
  • Loading branch information
erizmr committed Jul 13, 2022
1 parent c40e19a commit 11f0cb9
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
9 changes: 9 additions & 0 deletions taichi/transforms/auto_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,15 @@ class MakeDual : public ADTransform {
}
}

void visit(TernaryOpStmt *stmt) override {
TI_ASSERT(stmt->op_type == TernaryOpType::select);
auto zero = insert<ConstStmt>(TypedConstant(stmt->ret_type));
accumulate(stmt, insert<TernaryOpStmt>(TernaryOpType::select, stmt->op1,
load(dual(stmt->op2)), zero));
accumulate(stmt, insert<TernaryOpStmt>(TernaryOpType::select, stmt->op1,
zero, load(dual(stmt->op3))));
}

void visit(IfStmt *if_stmt) override {
if (if_stmt->true_statements) {
std::vector<Stmt *> true_statements;
Expand Down
38 changes: 38 additions & 0 deletions tests/python/test_ad_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,44 @@ def func():
assert y.grad[i] == (not i % 2) * 1.0


@test_utils.test()
def test_select_fwd():
N = 5
loss = ti.field(ti.f32, shape=N)
x = ti.field(ti.f32, shape=N)
y = ti.field(ti.f32, shape=N)
ti.root.lazy_dual()

for i in range(N):
x[i] = i
y[i] = -i

@ti.kernel
def func():
for i in range(N):
loss[i] = ti.select(i % 2, x[i], y[i])

with ti.ad.FwdMode(loss=loss, parameters=x, seed=[1.0 for _ in range(N)]):
func()

for i in range(N):
if i % 2:
assert loss[i] == i
else:
assert loss[i] == -i
assert loss.dual[i] == i % 2 * 1.0

with ti.ad.FwdMode(loss=loss, parameters=y, seed=[1.0 for _ in range(N)]):
func()

for i in range(N):
if i % 2:
assert loss[i] == i
else:
assert loss[i] == -i
assert loss.dual[i] == (not i % 2) * 1.0


@test_utils.test()
def test_obey_kernel_simplicity():
x = ti.field(ti.f32)
Expand Down

0 comments on commit 11f0cb9

Please sign in to comment.