From 6b3b92f5356d92b6a1102f4a9597d70442957fe5 Mon Sep 17 00:00:00 2001 From: mingrui Date: Tue, 12 Jul 2022 12:04:53 +0800 Subject: [PATCH 1/2] [autodiff] Add test for TernaryOpStmt in reverse ad --- tests/python/test_ad_basics.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/python/test_ad_basics.py b/tests/python/test_ad_basics.py index 371d6f3cdfb10..d1923fc769ab6 100644 --- a/tests/python/test_ad_basics.py +++ b/tests/python/test_ad_basics.py @@ -259,6 +259,35 @@ def test_pow_f64(tifunc, npfunc): grad_test_fwd(tifunc, npfunc) +@test_utils.test() +def test_select(): + N = 5 + loss = ti.field(ti.f32, shape=N) + x = ti.field(ti.f32, shape=N) + y = ti.field(ti.f32, shape=N) + ti.root.lazy_grad() + + for i in range(N): + x[i] = i + y[i] = -i + loss.grad[i] = 1.0 + + @ti.kernel + def func(): + for i in range(N): + loss[i] += ti.select(i % 2, x[i], y[i]) + + func() + func.grad() + for i in range(N): + if i % 2: + loss[i] = i + else: + loss[i] = -i + assert x.grad[i] == i % 2 * 1.0 + assert y.grad[i] == (not i % 2) * 1.0 + + @test_utils.test() def test_obey_kernel_simplicity(): x = ti.field(ti.f32) From 882c5c40e624e4d7c3ac79f43332655357343143 Mon Sep 17 00:00:00 2001 From: mingrui Date: Tue, 12 Jul 2022 16:10:46 +0800 Subject: [PATCH 2/2] update the test case --- tests/python/test_ad_basics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/test_ad_basics.py b/tests/python/test_ad_basics.py index d1923fc769ab6..b6c1687b89632 100644 --- a/tests/python/test_ad_basics.py +++ b/tests/python/test_ad_basics.py @@ -281,9 +281,9 @@ def func(): func.grad() for i in range(N): if i % 2: - loss[i] = i + assert loss[i] == i else: - loss[i] = -i + assert loss[i] == -i assert x.grad[i] == i % 2 * 1.0 assert y.grad[i] == (not i % 2) * 1.0