Skip to content

Commit

Permalink
Fix jacfwd (#734)
Browse files Browse the repository at this point in the history
* fix + test

* add warning + disable static
  • Loading branch information
lockwo committed May 23, 2024
1 parent 1f7908b commit d4f6a0e
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 6 deletions.
23 changes: 17 additions & 6 deletions equinox/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,17 +429,13 @@ def _fun(_diff_x):
_out, _aux = _out
else:
_aux = None
_dynamic_out, _static_out = partition(
_out, lambda j: isinstance(j, ad.JVPTracer)
)
return _dynamic_out, (_static_out, _aux)
return _out, _aux

if self.rev:
jacobian = jax.jacrev
else:
jacobian = jax.jacfwd
dynamic_out, (static_out, aux) = jacobian(_fun, has_aux=True)(diff_x)
out = combine(dynamic_out, static_out)
out, aux = jacobian(_fun, has_aux=True)(diff_x)
if self.has_aux:
return out, aux
else:
Expand All @@ -460,6 +456,11 @@ def filter_jacfwd(fun, has_aux: bool = False):
A function with the same arguments as `fun`.
!!! warning
The outputs of `fun` must be jax types, the filtering is only applied
to the input not the output.
If `has_aux is False` then this function returns just the Jacobian of `fun` with
respect to its first argument.
Expand All @@ -483,6 +484,11 @@ def filter_jacrev(fun, has_aux: bool = False):
A function with the same arguments as `fun`.
!!! warning
The outputs of `fun` must be jax types, the filtering is only applied
to the input not the output.
If `has_aux is False` then this function returns just the Jacobian of `fun` with
respect to its first argument.
Expand All @@ -503,6 +509,11 @@ def filter_hessian(fun, has_aux: bool = False):
A function with the same arguments as `fun`.
!!! warning
The outputs of `fun` must be jax types, the filtering is only applied
to the input not the output.
If `has_aux is False` then this function returns just the Hessian of `fun` with
respect to its first argument.
Expand Down
35 changes: 35 additions & 0 deletions tests/test_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,41 @@ def f_bwd(res, g):
assert tree_allclose(jax_hess, eqx_hess)


def test_pytree_jacfwd():
class NeuralNetwork(eqx.Module):
layers: list
extra_bias: jax.Array

def __init__(self, key):
key1, key2, key3 = jax.random.split(key, 3)
self.layers = [
eqx.nn.Linear(2, 8, key=key1),
eqx.nn.Linear(8, 8, key=key2),
eqx.nn.Linear(8, 2, key=key3),
]
self.extra_bias = jax.numpy.ones(2)

def __call__(self, x):
for layer in self.layers[:-1]:
x = jax.nn.relu(layer(x))
return self.layers[-1](x) + self.extra_bias

def loss(model, x, y):
pred_y = jax.vmap(model)(x)
return jax.numpy.mean((y - pred_y) ** 2)

x_key, y_key, model_key = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(x_key, (3, 2))
y = jax.random.normal(y_key, (3, 2))
model = NeuralNetwork(model_key)
assert tree_allclose(
eqx.filter_grad(loss)(model, x, y), eqx.filter_jacfwd(loss)(model, x, y)
)
assert tree_allclose(
eqx.filter_grad(loss)(model, x, y), eqx.filter_jacrev(loss)(model, x, y)
)


def test_filter_custom_jvp_symbolic_zero():
@eqx.filter_custom_jvp
def f(x, y):
Expand Down

0 comments on commit d4f6a0e

Please sign in to comment.