Skip to content

Commit

Permalink
[Doc] Add introduction to forward mode autodiff (#5680)
Browse files Browse the repository at this point in the history
* [autodiff] Add introduction to forward mode autodiff

* Apply suggestions from code review

Co-authored-by: Ailing  <ailzhang@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Ailing  <ailzhang@users.noreply.github.com>
  • Loading branch information
erizmr and ailzhang committed Aug 10, 2022
1 parent eb0ab1c commit 26ef559
Showing 1 changed file with 76 additions and 0 deletions.
76 changes: 76 additions & 0 deletions docs/lang/articles/differentiable/differentiable_programming.md
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,79 @@ Check out [the DiffTaichi paper](https://arxiv.org/pdf/1910.00935.pdf)
and [video](https://www.youtube.com/watch?v=Z1xvAZve9aE) to learn more
about Taichi differentiable programming.
:::


## Forward-Mode Autodiff

There are two modes of automatic differentiation, forward and reverse mode. The forward mode provides a function to compute Jacobian-Vector Product (JVP), which can compute one column of the Jacobian matrix at a time. The reverse mode supports computing Vector-Jacobian Product (VJP), i.e., one row of the Jacobian matrix at a time. Therefore, for functions which have more inputs than outputs, reverse mode is more efficient. The `ti.ad.Tape` and `kernel.grad()` are built on the reverse mode. The forward mode is more efficient when handling functions whose outputs are more than inputs. Taichi autodiff also supports forward mode.

### Using `ti.ad.FwdMode`
The usage of `ti.ad.FwdMode` is very similar to `ti.ad.Tape`. Here we reuse the example for reverse mode above for an explanation.
1. Enable `needs_dual=True` option when declaring fields involved in the derivative chain.
2. Use context manager with `ti.ad.FwdMode(loss=y, param=x)`: to capture the kernel invocations which you want to automatically differentiate. The `loss` and `param` are the output and input of the function respectively.
3. Now dy/dx value at current x is available at function output `y.dual[None]`.
The following code snippet explains the steps above:

```python
import taichi as ti
ti.init()

x = ti.field(dtype=ti.f32, shape=(), needs_dual=True)
y = ti.field(dtype=ti.f32, shape=(), needs_dual=True)


@ti.kernel
def compute_y():
y[None] = ti.sin(x[None])


with ti.ad.FwdMode(loss=y, param=x):
compute_y()

print('dy/dx =', y.dual[None], ' at x =', x[None])
```

:::note
The `dual` here indicates `dual number`in math. The reason for using the name is that forwar-mode autodiff is equivalent to evaluating function with dual numbers.
:::

:::note
The `ti.ad.FwdMode` automatically clears the dual field of `loss`.
:::

ti.ad.FwdMode support multiple inputs and outputs. The param can be a N-D field and the loss can be an individual or a list of N-D fields. The argument `seed` is the 'vector' in Jacobian-vector product, which used to control the parameter that is computed derivative with respect to. Here we show three cases with multiple inputs and outputs. With `seed=[1.0, 0.0] `or `seed=[0.0, 1.0]` , we can compute the derivatives solely with respect to `x_0` or `x_1`.

```python
import taichi as ti
ti.init()
N_param = 2
N_loss = 5
x = ti.field(dtype=ti.f32, shape=N_param, needs_dual=True)
y = ti.field(dtype=ti.f32, shape=N_loss, needs_dual=True)


@ti.kernel
def compute_y():
for i in range(N_loss):
for j in range(N_param):
y[i] += i * ti.sin(x[j])


# Compute derivatives respect to x_0
with ti.ad.FwdMode(loss=y, param=x, seed=[1.0, 0.0]):
compute_y()
print('dy/dx_0 =', y.dual, ' at x_0 =', x[0])

# Compute derivatives respect to x_1
with ti.ad.FwdMode(loss=y, param=x, seed=[0.0, 1.0]):
compute_y()
print('dy/dx_1 =', y.dual, ' at x_1 =', x[1])
```

:::note
The `seed` argument is required if the `param` is not a scalar field.
:::

:::tip
Similar to reverse mode autodiff, Taichi provides an API `ti.root.lazy_dual()` that automatically places the dual fields following the layout of their primal fields.
:::

0 comments on commit 26ef559

Please sign in to comment.