diff --git a/python/taichi/ad/_ad.py b/python/taichi/ad/_ad.py index c0e04934b461c..04117def6c0b5 100644 --- a/python/taichi/ad/_ad.py +++ b/python/taichi/ad/_ad.py @@ -283,6 +283,11 @@ def shape_flatten(shape): else: assert parameters_shape_flatten == len(self.seed) + # Clear gradients + if self.clear_gradients: + # TODO: the clear gradients should be controlled to clear adjoint/dual/adjoint_visited respectively + clear_all_gradients() + # Set seed for each variable if len(self.seed) == 1: if len(self.param.shape) == 0: @@ -294,11 +299,6 @@ def shape_flatten(shape): else: self.param.dual.from_numpy(np.array(self.seed, dtype=np.float32)) - # Clear gradients - if self.clear_gradients: - for ls in self.loss: - ls.dual.fill(0) - # Attach the context manager to the runtime self.runtime.fwd_mode_manager = self diff --git a/tests/python/test_ad_basics_fwd.py b/tests/python/test_ad_basics_fwd.py index 7fe7932a71f67..6eed8d8a9a7a0 100644 --- a/tests/python/test_ad_basics_fwd.py +++ b/tests/python/test_ad_basics_fwd.py @@ -103,3 +103,23 @@ def func(): with ti.ad.FwdMode(loss=d, param=c): func() + + +@test_utils.test() +def test_clear_all_dual_field(): + x = ti.field(float, shape=(), needs_dual=True) + y = ti.field(float, shape=(), needs_dual=True) + loss = ti.field(float, shape=(), needs_dual=True) + + x[None] = 2.0 + y[None] = 3.0 + + @ti.kernel + def clear_dual_test(): + y[None] = x[None]**2 + loss[None] += y[None] + + for _ in range(5): + with ti.ad.FwdMode(loss=loss, param=x): + clear_dual_test() + assert y.dual[None] == 4.0