Fix jacfwd #734

Merged · 2 commits · May 23, 2024

23 changes: 17 additions & 6 deletions equinox/_ad.py
@@ -429,17 +429,13 @@ def _fun(_diff_x):
                 _out, _aux = _out
             else:
                 _aux = None
-            _dynamic_out, _static_out = partition(
-                _out, lambda j: isinstance(j, ad.JVPTracer)
-            )
-            return _dynamic_out, (_static_out, _aux)
+            return _out, _aux
 
         if self.rev:
             jacobian = jax.jacrev
         else:
             jacobian = jax.jacfwd
-        dynamic_out, (static_out, aux) = jacobian(_fun, has_aux=True)(diff_x)
-        out = combine(dynamic_out, static_out)
+        out, aux = jacobian(_fun, has_aux=True)(diff_x)
         if self.has_aux:
             return out, aux
         else:
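
For orientation, here is a minimal sketch (with assumed names, not Equinox's actual internals) of the wrapper pattern this hunk modifies: the first argument is partitioned into differentiable arrays and static leaves, only the arrays are traced through `jax.jacfwd`, and after this change the output of `fun` is returned as-is rather than being partitioned on `JVPTracer`s, which is why it must consist of JAX types.

```python
import equinox as eqx
import jax

def sketch_filter_jacfwd(fun):
    # Hypothetical simplified version, for illustration only.
    def wrapper(x, *args, **kwargs):
        # Split the first argument into differentiable arrays and static leaves.
        diff_x, nondiff_x = eqx.partition(x, eqx.is_inexact_array)

        def _fun(_diff_x):
            # Recombine before calling the user's function.
            _x = eqx.combine(_diff_x, nondiff_x)
            return fun(_x, *args, **kwargs)

        # The output of `fun` is differentiated directly, so it must be a JAX type.
        return jax.jacfwd(_fun)(diff_x)

    return wrapper
```
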
@@ -460,6 +456,11 @@ def filter_jacfwd(fun, has_aux: bool = False):
 
     A function with the same arguments as `fun`.
 
+    !!! warning
+
+        The outputs of `fun` must be JAX types; filtering is only applied to
+        the inputs, not the outputs.
+
     If `has_aux is False` then this function returns just the Jacobian of `fun` with
     respect to its first argument.
 
@@ -483,6 +484,11 @@ def filter_jacrev(fun, has_aux: bool = False):
 
     A function with the same arguments as `fun`.
 
+    !!! warning
+
+        The outputs of `fun` must be JAX types; filtering is only applied to
+        the inputs, not the outputs.
+
     If `has_aux is False` then this function returns just the Jacobian of `fun` with
     respect to its first argument.
 
@@ -503,6 +509,11 @@ def filter_hessian(fun, has_aux: bool = False):
 
     A function with the same arguments as `fun`.
 
+    !!! warning
+
+        The outputs of `fun` must be JAX types; filtering is only applied to
+        the inputs, not the outputs.
+
     If `has_aux is False` then this function returns just the Hessian of `fun` with
     respect to its first argument.
 
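To illustrate the warning added above, a small example of my own (not part of this PR): the first argument may freely mix arrays with static fields, since filtering strips the non-arrays before differentiation, but the value returned by the differentiated function must be built from JAX arrays.

```python
import equinox as eqx
import jax
import jax.numpy as jnp

# `linear` is a pytree mixing JAX arrays (weight, bias) and static Python data.
linear = eqx.nn.Linear(2, 2, key=jax.random.PRNGKey(0))

def loss(model, x):
    # Returns a plain JAX scalar, so it satisfies the warning above.
    return jnp.sum(model(x) ** 2)

x = jnp.ones(2)
# Jacobian with respect to the array leaves of `linear`; the static fields are ignored.
jac = eqx.filter_jacfwd(loss)(linear, x)
```
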
35 changes: 35 additions & 0 deletions tests/test_ad.py
@@ -483,6 +483,41 @@ def f_bwd(res, g):
     assert tree_allclose(jax_hess, eqx_hess)
 
 
+def test_pytree_jacfwd():
+    class NeuralNetwork(eqx.Module):
+        layers: list
+        extra_bias: jax.Array
+
+        def __init__(self, key):
+            key1, key2, key3 = jax.random.split(key, 3)
+            self.layers = [
+                eqx.nn.Linear(2, 8, key=key1),
+                eqx.nn.Linear(8, 8, key=key2),
+                eqx.nn.Linear(8, 2, key=key3),
+            ]
+            self.extra_bias = jax.numpy.ones(2)
+
+        def __call__(self, x):
+            for layer in self.layers[:-1]:
+                x = jax.nn.relu(layer(x))
+            return self.layers[-1](x) + self.extra_bias
+
+    def loss(model, x, y):
+        pred_y = jax.vmap(model)(x)
+        return jax.numpy.mean((y - pred_y) ** 2)
+
+    x_key, y_key, model_key = jax.random.split(jax.random.PRNGKey(0), 3)
+    x = jax.random.normal(x_key, (3, 2))
+    y = jax.random.normal(y_key, (3, 2))
+    model = NeuralNetwork(model_key)
+    assert tree_allclose(
+        eqx.filter_grad(loss)(model, x, y), eqx.filter_jacfwd(loss)(model, x, y)
+    )
+    assert tree_allclose(
+        eqx.filter_grad(loss)(model, x, y), eqx.filter_jacrev(loss)(model, x, y)
+    )
+
+
 def test_filter_custom_jvp_symbolic_zero():
     @eqx.filter_custom_jvp
     def f(x, y):