11 changes: 9 additions & 2 deletions diffrax/adjoint.py
@@ -276,10 +276,17 @@ class BacksolveAdjoint(AbstractAdjoint):
"optimise-then-discretise", the "continuous adjoint method" or simply the "adjoint
method".

This method implies very low memory usage, but is usually relatively slow, and the
This method implies very low memory usage, but the
computed gradients will only be approximate. As such, other methods are generally
preferred unless exceeding available memory is a concern.
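
For example, a minimal sketch of the usually-preferred alternative (assuming the default `diffrax.RecursiveCheckpointAdjoint`, which backpropagates through the solve exactly at the cost of extra memory):

```python
import diffrax
import jax.numpy as jnp

def f(t, y, args):
    return -y

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(f), diffrax.Euler(), 0, 1, 0.1, jnp.array([1.0]),
    adjoint=diffrax.RecursiveCheckpointAdjoint(),  # the default; exact gradients
)
```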

This will compute gradients with respect to the `terms`, `y0` and `args` arguments
passed to [`diffrax.diffeqsolve`][]. If you attempt to compute gradients with
respect to anything else (for example `t0`, or arguments passed via closure), then
a `CustomVJPException` will be raised. See also
[this FAQ](../../further_details/faq/#im-getting-a-customvjpexception)
entry.
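
For example, a minimal sketch of a supported pattern (the vector field and parameter here are made up for illustration): a value passed through `args` can be differentiated.

```python
import diffrax
import jax
import jax.numpy as jnp

def f(t, y, args):
    return -args * y  # `args` is a scalar decay rate here

def loss(rate):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f), diffrax.Euler(), 0, 1, 0.1, jnp.array([1.0]),
        args=rate, adjoint=diffrax.BacksolveAdjoint(),
    )
    return jnp.sum(sol.ys)

jax.grad(loss)(0.5)  # fine: `rate` enters the solve through `args`
```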

!!! note

This was popularised by [this paper](https://arxiv.org/abs/1806.07366). For
@@ -290,7 +297,7 @@ class BacksolveAdjoint(AbstractAdjoint):

Using this method prevents computing forward-mode autoderivatives of
[`diffrax.diffeqsolve`][]. (That is to say, `jax.jvp` will not work.)
"""
""" # noqa: E501

kwargs: Dict[str, Any]

59 changes: 38 additions & 21 deletions docs/further_details/faq.md
@@ -4,7 +4,7 @@

Try switching to 64-bit precision. (Instead of the 32-bit that is the default in JAX.) [See here](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
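
For example (a minimal sketch; `jax_enable_x64` is the standard JAX flag for enabling 64-bit precision):

```python
import jax

# Must be set before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)
```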

### I'm getting zero gradient for one of my model parameters.
### I'm getting a `CustomVJPException`.

This can happen if you use [`diffrax.BacksolveAdjoint`][] incorrectly.

Expand All @@ -14,39 +14,56 @@ Gradients will be computed for:
- Everything in the `y0` PyTree passed to `diffeqsolve(..., y0=y0)`.
- Everything in the `terms` PyTree passed to `diffeqsolve(terms, ...)`.

Attempting to compute gradients with respect to anything else will result in this exception.

!!! example

    Here is a minimal example of **wrong** code that will raise this exception.

    ```python
    from diffrax import BacksolveAdjoint, diffeqsolve, Euler, ODETerm
    import equinox as eqx
    import jax.numpy as jnp
    import jax.random as jr

    mlp = eqx.nn.MLP(1, 1, 8, 2, key=jr.PRNGKey(0))

    @eqx.filter_jit
    @eqx.filter_value_and_grad
    def run(model):
        def f(t, y, args):  # `model` is captured via closure; it is not part of the `terms` PyTree.
            return model(y)
        sol = diffeqsolve(ODETerm(f), Euler(), 0, 1, 0.1, jnp.array([1.0]),
                          adjoint=BacksolveAdjoint())
        return jnp.sum(sol.ys)

    run(mlp)
    ```

!!! example

    The corrected version of the previous example is as follows. In this case, the model is properly part of the PyTree structure of `terms`, so its parameters are visible to `diffeqsolve` and gradients will be computed with respect to them.

    ```python
    from diffrax import BacksolveAdjoint, diffeqsolve, Euler, ODETerm
    import equinox as eqx
    import jax.numpy as jnp
    import jax.random as jr

    mlp = eqx.nn.MLP(1, 1, 8, 2, key=jr.PRNGKey(0))

    class VectorField(eqx.Module):
        model: eqx.Module

        def __call__(self, t, y, args):
            return self.model(y)

    @eqx.filter_jit
    @eqx.filter_value_and_grad
    def run(model):
        f = VectorField(model)
        sol = diffeqsolve(ODETerm(f), Euler(), 0, 1, 0.1, jnp.array([1.0]),
                          adjoint=BacksolveAdjoint())
        return jnp.sum(sol.ys)

    run(mlp)
    ```
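
!!! example

    Since gradients are also computed for everything in `args`, another possible fix (a sketch of an alternative, not the FAQ's own example) is to pass the model through `args` rather than capturing it by closure.

    ```python
    from diffrax import BacksolveAdjoint, diffeqsolve, Euler, ODETerm
    import equinox as eqx
    import jax.numpy as jnp
    import jax.random as jr

    mlp = eqx.nn.MLP(1, 1, 8, 2, key=jr.PRNGKey(0))

    @eqx.filter_jit
    @eqx.filter_value_and_grad
    def run(model):
        def f(t, y, args):
            return args(y)  # the model arrives through `args`, so it is differentiated
        sol = diffeqsolve(ODETerm(f), Euler(), 0, 1, 0.1, jnp.array([1.0]),
                          args=model, adjoint=BacksolveAdjoint())
        return jnp.sum(sol.ys)

    run(mlp)
    ```
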
52 changes: 52 additions & 0 deletions test/test_adjoint.py
@@ -4,6 +4,7 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import pytest

from helpers import shaped_allclose
@@ -141,3 +142,54 @@ def solve(y0):
return jnp.sum(sol.ys)

jax.grad(solve)(2.0)


def test_closure_errors():
    mlp = eqx.nn.MLP(1, 1, 8, 2, key=jrandom.PRNGKey(0))

    @eqx.filter_jit
    @eqx.filter_value_and_grad
    def run(model):
        def f(t, y, args):
            return model(y)

        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(f),
            diffrax.Euler(),
            0,
            1,
            0.1,
            jnp.array([1.0]),
            adjoint=diffrax.BacksolveAdjoint(),
        )
        return jnp.sum(sol.ys)

    with pytest.raises(jax.interpreters.ad.CustomVJPException):
        run(mlp)


def test_closure_fixed():
    mlp = eqx.nn.MLP(1, 1, 8, 2, key=jrandom.PRNGKey(0))

    class VectorField(eqx.Module):
        model: eqx.Module

        def __call__(self, t, y, args):
            return self.model(y)

    @eqx.filter_jit
    @eqx.filter_value_and_grad
    def run(model):
        f = VectorField(model)
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(f),
            diffrax.Euler(),
            0,
            1,
            0.1,
            jnp.array([1.0]),
            adjoint=diffrax.BacksolveAdjoint(),
        )
        return jnp.sum(sol.ys)

    run(mlp)