From 67c3bd0cf98bca91943aeba5453d10c563713c41 Mon Sep 17 00:00:00 2001
From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
Date: Fri, 8 Jul 2022 14:25:39 +0100
Subject: [PATCH] Updated BacksolveAdjoint docs

---
 diffrax/adjoint.py          | 11 +++++--
 docs/further_details/faq.md | 59 ++++++++++++++++++++++++-------------
 test/test_adjoint.py        | 52 ++++++++++++++++++++++++++++++++
 3 files changed, 99 insertions(+), 23 deletions(-)

diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py
index 79dae8e8..7955d991 100644
--- a/diffrax/adjoint.py
+++ b/diffrax/adjoint.py
@@ -276,10 +276,17 @@ class BacksolveAdjoint(AbstractAdjoint):
     "optimise-then-discretise", the "continuous adjoint method" or simply the "adjoint
     method".
 
-    This method implies very low memory usage, but is usually relatively slow, and the
+    This method implies very low memory usage, but the
     computed gradients will only be approximate. As such other methods are generally
     preferred unless exceeding memory is a concern.
 
+    This will compute gradients with respect to the `terms`, `y0` and `args` arguments
+    passed to [`diffrax.diffeqsolve`][]. If you attempt to compute gradients with
+    respect to anything else (for example `t0`, or arguments passed via closure), then
+    a `CustomVJPException` will be raised. See also
+    [this FAQ](../../further_details/faq/#im-getting-a-customvjpexception)
+    entry.
+
     !!! note
 
         This was popularised by [this paper](https://arxiv.org/abs/1806.07366). For
@@ -290,7 +297,7 @@ class BacksolveAdjoint(AbstractAdjoint):
 
         Using this method prevents computing forward-mode autoderivatives of
         [`diffrax.diffeqsolve`][]. (That is to say, `jax.jvp` will not work.)
-    """
+    """  # noqa: E501
 
     kwargs: Dict[str, Any]
 
diff --git a/docs/further_details/faq.md b/docs/further_details/faq.md
index e4c697c9..626926a4 100644
--- a/docs/further_details/faq.md
+++ b/docs/further_details/faq.md
@@ -4,7 +4,7 @@
 
 Try switching to 64-bit precision. (Instead of the 32-bit that is the default in JAX.) [See here](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
 
-### I'm getting zero gradient for one of my model parameters.
+### I'm getting a `CustomVJPException`.
 
 This can happen if you use [`diffrax.BacksolveAdjoint`][] incorrectly.
 
@@ -14,39 +14,56 @@ Gradients will be computed for:
 - Everything in the `y0` PyTree passed to `diffeqsolve(..., y0=y0)`.
 - Everything in the `terms` PyTree passed to `diffeqsolve(terms, ...)`.
 
+Attempting to compute gradients with respect to anything else will result in this exception.
+
 !!! example
 
-    Gradients through `args` and `y0` are self-explanatory. Meanwhile, a common example of computing gradients through `terms` is if using an [Equinox](https://github.com/patrick-kidger/equinox) module to represent a parameterised vector field. For example:
+    Here is a minimal example of **wrong** code that will raise this exception.
 
     ```python
+    from diffrax import BacksolveAdjoint, diffeqsolve, Euler, ODETerm
     import equinox as eqx
-    import diffrax
+    import jax.numpy as jnp
+    import jax.random as jr
 
-    class Func(eqx.Module):
-        mlp: eqx.nn.MLP
+    mlp = eqx.nn.MLP(1, 1, 8, 2, key=jr.PRNGKey(0))
 
-        def __call__(self, t, y, args):
-            return self.mlp(y)
+    @eqx.filter_jit
+    @eqx.filter_value_and_grad
+    def run(model):
+        def f(t, y, args):  # `model` is captured via closure; it is not part of the `terms` PyTree.
+            return model(y)
+        sol = diffeqsolve(ODETerm(f), Euler(), 0, 1, 0.1, jnp.array([1.0]),
+                          adjoint=BacksolveAdjoint())
+        return jnp.sum(sol.ys)
 
-    mlp = eqx.nn.MLP(...)
-    func = Func(mlp)
-    term = diffrax.ODETerm(func)
-    diffrax.diffeqsolve(term, ..., adjoint=diffrax.BacksolveAdjoint())
+    run(mlp)
     ```
 
-    In this case `diffrax.ODETerm`, `Func` and `eqx.nn.MLP` are all PyTrees, so all of the parameters inside `mlp` are visible to `diffeqsolve` and it can compute gradients with respect to them.
+!!! example
+
+    The corrected version of the previous example is as follows. In this case, the model is properly part of the PyTree structure of `terms`.
+
+    ```python
+    from diffrax import BacksolveAdjoint, diffeqsolve, Euler, ODETerm
+    import equinox as eqx
+    import jax.numpy as jnp
+    import jax.random as jr
 
-However if you were to do:
+    mlp = eqx.nn.MLP(1, 1, 8, 2, key=jr.PRNGKey(0))
 
-```python
-model = ...
+    class VectorField(eqx.Module):
+        model: eqx.Module
 
-def func(t, y, args):
-    return model(y)
+        def __call__(self, t, y, args):
+            return self.model(y)
 
-term = diffrax.ODETerm(func)
-diffrax.diffeqsolve(term, ..., adjoint=diffrax.BacksolveAdjoint())
-```
+    @eqx.filter_jit
+    @eqx.filter_value_and_grad
+    def run(model):
+        f = VectorField(model)
+        sol = diffeqsolve(ODETerm(f), Euler(), 0, 1, 0.1, jnp.array([1.0]), adjoint=BacksolveAdjoint())
+        return jnp.sum(sol.ys)
 
-then the parameters of `model` are not visible to `diffeqsolve` so gradients will not be computed with respect to them.
+    run(mlp)
+    ```
diff --git a/test/test_adjoint.py b/test/test_adjoint.py
index 93f2623e..10a8fedd 100644
--- a/test/test_adjoint.py
+++ b/test/test_adjoint.py
@@ -4,6 +4,7 @@
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+import jax.random as jrandom
 import pytest
 
 from helpers import shaped_allclose
@@ -141,3 +142,54 @@ def solve(y0):
         return jnp.sum(sol.ys)
 
     jax.grad(solve)(2.0)
+
+
+def test_closure_errors():
+    mlp = eqx.nn.MLP(1, 1, 8, 2, key=jrandom.PRNGKey(0))
+
+    @eqx.filter_jit
+    @eqx.filter_value_and_grad
+    def run(model):
+        def f(t, y, args):
+            return model(y)
+
+        sol = diffrax.diffeqsolve(
+            diffrax.ODETerm(f),
+            diffrax.Euler(),
+            0,
+            1,
+            0.1,
+            jnp.array([1.0]),
+            adjoint=diffrax.BacksolveAdjoint(),
+        )
+        return jnp.sum(sol.ys)
+
+    with pytest.raises(jax.interpreters.ad.CustomVJPException):
+        run(mlp)
+
+
+def test_closure_fixed():
+    mlp = eqx.nn.MLP(1, 1, 8, 2, key=jrandom.PRNGKey(0))
+
+    class VectorField(eqx.Module):
+        model: eqx.Module
+
+        def __call__(self, t, y, args):
+            return self.model(y)
+
+    @eqx.filter_jit
+    @eqx.filter_value_and_grad
+    def run(model):
+        f = VectorField(model)
+        sol = diffrax.diffeqsolve(
+            diffrax.ODETerm(f),
+            diffrax.Euler(),
+            0,
+            1,
+            0.1,
+            jnp.array([1.0]),
+            adjoint=diffrax.BacksolveAdjoint(),
+        )
+        return jnp.sum(sol.ys)
+
+    run(mlp)